"""Preprocess the OGB MAG240M dataset for DGL.

Builds the heterogeneous citation/authorship graph, derives author and
institution features by averaging the features of connected papers, and
optionally converts the result into a homogeneous graph (as required by
the RGAT baseline).
"""
import argparse
import os

import numpy as np
import ogb
import torch
import tqdm
from ogb.lsc import MAG240MDataset

import dgl
import dgl.function as fn
# Command-line options: dataset location, output paths, and graph format.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--rootdir",
    type=str,
    default=".",
    help="Directory to download the OGB dataset.",
)
parser.add_argument(
    "--author-output-path", type=str, help="Path to store the author features."
)
parser.add_argument(
    "--inst-output-path",
    type=str,
    help="Path to store the institution features.",
)
parser.add_argument(
    "--graph-output-path", type=str, help="Path to store the graph."
)
parser.add_argument(
    "--graph-format",
    type=str,
    default="csc",
    help="Graph format (coo, csr or csc).",
)
parser.add_argument(
    "--graph-as-homogeneous",
    action="store_true",
    help="Store the graph as DGL homogeneous graph.",
)
parser.add_argument(
    "--full-output-path",
    type=str,
    help="Path to store features of all nodes.  Effective only when graph is homogeneous.",
)
args = parser.parse_args()

print("Building graph")
dataset = MAG240MDataset(root=args.rootdir)

# Raw (2, E) edge-index arrays from OGB:
#   author --writes--> paper, paper --cites--> paper,
#   author --affiliated with--> institution.
ei_writes = dataset.edge_index("author", "writes", "paper")
ei_cites = dataset.edge_index("paper", "paper")
ei_affiliated = dataset.edge_index("author", "institution")

# Node-ID layout used for the homogeneous graph: authors first, then
# institutions, then papers.  (This matches the node-type ordering DGL
# produces in to_homogeneous, verified by the asserts further down.)
author_offset = 0
inst_offset = author_offset + dataset.num_authors
paper_offset = inst_offset + dataset.num_institutions

# Build the heterogeneous graph.  Each relation is also added in the
# reverse direction so messages can flow both ways, and citation edges
# are made symmetric by concatenating the reversed edge list.
g = dgl.heterograph(
    {
        ("author", "write", "paper"): (ei_writes[0], ei_writes[1]),
        ("paper", "write-by", "author"): (ei_writes[1], ei_writes[0]),
        ("author", "affiliate-with", "institution"): (
            ei_affiliated[0],
            ei_affiliated[1],
        ),
        ("institution", "affiliate", "author"): (
            ei_affiliated[1],
            ei_affiliated[0],
        ),
        ("paper", "cite", "paper"): (
            np.concatenate([ei_cites[0], ei_cites[1]]),
            np.concatenate([ei_cites[1], ei_cites[0]]),
        ),
    }
)

# Paper features come from the dataset; author/institution features are
# derived below and written to disk-backed buffers as float16 (the same
# dtype as the on-disk paper features).
paper_feat = dataset.paper_feat

author_feat = np.memmap(
    args.author_output_path,
    mode="w+",
    dtype="float16",
    shape=(dataset.num_authors, dataset.num_paper_features),
)
inst_feat = np.memmap(
    args.inst_output_path,
    mode="w+",
    dtype="float16",
    shape=(dataset.num_institutions, dataset.num_paper_features),
)

# Iteratively process author features along the feature dimension.
# Working on a narrow slice of columns at a time keeps peak memory
# bounded regardless of the total feature width.
BLOCK_COLS = 16
with tqdm.trange(0, dataset.num_paper_features, BLOCK_COLS) as tq:
    for start in tq:
        tq.set_postfix_str("Reading paper features...")
        # Load one column slice of the paper features (float32 for
        # message passing; converted back to float16 on write-out).
        g.nodes["paper"].data["x"] = torch.FloatTensor(
            paper_feat[:, start : start + BLOCK_COLS].astype("float32")
        )
        # Author feature = mean of the papers each author wrote.
        tq.set_postfix_str("Computing author features...")
        g.update_all(fn.copy_u("x", "m"), fn.mean("m", "x"), etype="write-by")
        # Institution feature = mean of its affiliated authors' features.
        tq.set_postfix_str("Computing institution features...")
        g.update_all(
            fn.copy_u("x", "m"), fn.mean("m", "x"), etype="affiliate-with"
        )
        tq.set_postfix_str("Writing author features...")
        author_feat[:, start : start + BLOCK_COLS] = (
            g.nodes["author"].data["x"].numpy().astype("float16")
        )
        tq.set_postfix_str("Writing institution features...")
        inst_feat[:, start : start + BLOCK_COLS] = (
            g.nodes["institution"].data["x"].numpy().astype("float16")
        )
        # Drop the per-slice node features before the next iteration.
        del g.nodes["paper"].data["x"]
        del g.nodes["author"].data["x"]
        del g.nodes["institution"].data["x"]

# Ensure the memory-mapped buffers are persisted to disk.
author_feat.flush()
inst_feat.flush()

# Convert to homogeneous if needed.  (The RGAT baseline needs homogeneous graph)
if args.graph_as_homogeneous:
    # Process graph
    g = dgl.to_homogeneous(g)
    # DGL ensures that nodes with the same type are put together with the order preserved.
    # DGL also ensures that the node types are sorted in ascending order.
    # These asserts pin down the author/institution/paper node-ID layout
    # that the *_offset variables above rely on.
    assert torch.equal(
        g.ndata[dgl.NTYPE],
        torch.cat(
            [
                torch.full((dataset.num_authors,), 0),
                torch.full((dataset.num_institutions,), 1),
                torch.full((dataset.num_papers,), 2),
            ]
        ),
    )
    assert torch.equal(
        g.ndata[dgl.NID],
        torch.cat(
            [
                torch.arange(dataset.num_authors),
                torch.arange(dataset.num_institutions),
                torch.arange(dataset.num_papers),
            ]
        ),
    )
    # Keep the edge types compactly as uint8 and drop the bookkeeping
    # fields that are no longer needed.
    g.edata["etype"] = g.edata[dgl.ETYPE].byte()
    del g.edata[dgl.ETYPE]
    del g.ndata[dgl.NTYPE]
    del g.ndata[dgl.NID]

    # Process feature
    full_feat = np.memmap(
        args.full_output_path,
        mode="w+",
        dtype="float16",
        shape=(
            dataset.num_authors + dataset.num_institutions + dataset.num_papers,
            dataset.num_paper_features,
        ),
    )
    # Concatenate the three per-type feature arrays into one buffer,
    # copying in row blocks to bound memory usage.
    BLOCK_ROWS = 100000
    for start in tqdm.trange(0, dataset.num_authors, BLOCK_ROWS):
        end = min(dataset.num_authors, start + BLOCK_ROWS)
        full_feat[author_offset + start : author_offset + end] = author_feat[
            start:end
        ]
    for start in tqdm.trange(0, dataset.num_institutions, BLOCK_ROWS):
        end = min(dataset.num_institutions, start + BLOCK_ROWS)
        full_feat[inst_offset + start : inst_offset + end] = inst_feat[
            start:end
        ]
    for start in tqdm.trange(0, dataset.num_papers, BLOCK_ROWS):
        end = min(dataset.num_papers, start + BLOCK_ROWS)
        full_feat[paper_offset + start : paper_offset + end] = paper_feat[
            start:end
        ]
    # Persist the combined feature file, consistent with the explicit
    # flushes of author_feat/inst_feat above.
    full_feat.flush()

# Convert the graph to the given format and save.  (The RGAT baseline needs CSC graph)
g = g.formats(args.graph_format)
dgl.save_graphs(args.graph_output_path, g)