import json

import dgl
import numpy as np
import torch as th
from ogb.nodeproppred import DglNodePropPredDataset

# Load OGB-MAG.
dataset = DglNodePropPredDataset(name="ogbn-mag")
hg_orig, labels = dataset[0]

# Rebuild the graph so that every canonical edge type (src, rel, dst) of the
# original graph is kept and additionally gets a reversed counterpart
# (dst, "rev-" + rel, src) whose endpoints are swapped.
subgs = {}
for etype in hg_orig.canonical_etypes:
    u, v = hg_orig.all_edges(etype=etype)
    subgs[etype] = (u, v)
    subgs[(etype[2], "rev-" + etype[1], etype[0])] = (v, u)
hg = dgl.heterograph(subgs)

# Only the "feat" tensor of the "paper" nodes is carried over from the
# original graph; no other node or edge data is copied here.
hg.nodes["paper"].data["feat"] = hg_orig.nodes["paper"].data["feat"]


# Split indices provided by the dataset, restricted to the "paper" node type.
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]["paper"]
val_idx = split_idx["valid"]["paper"]
test_idx = split_idx["test"]["paper"]
# squeeze() removes size-1 dimensions from the "paper" label tensor
# (presumably (N, 1) -> (N,) -- confirm against the dataset).
paper_labels = labels["paper"].squeeze()

# One boolean mask per split: length is the number of "paper" nodes, and the
# entries at that split's indices are set to True.
train_mask = th.zeros((hg.number_of_nodes("paper"),), dtype=th.bool)
train_mask[train_idx] = True
val_mask = th.zeros((hg.number_of_nodes("paper"),), dtype=th.bool)
val_mask[val_idx] = True
test_mask = th.zeros((hg.number_of_nodes("paper"),), dtype=th.bool)
test_mask[test_idx] = True

# Store masks and labels as "paper" node data next to "feat".
hg.nodes["paper"].data["train_mask"] = train_mask
hg.nodes["paper"].data["val_mask"] = val_mask
hg.nodes["paper"].data["test_mask"] = test_mask
hg.nodes["paper"].data["labels"] = paper_labels

# Read the JSON metadata file. The code below looks up "num_parts" and the
# per-partition "part-<id>" entries ("node_feats" / "edge_feats" paths) in it.
with open("outputs/mag.json") as json_file:
    metadata = json.load(json_file)

# For every partition listed in the metadata, slice the full graph's node and
# edge data down to the entries that belong to that partition and save them
# to the file paths named in the metadata.
for part_id in range(metadata["num_parts"]):
    # load_graphs() returns a pair whose first item is a list of graphs; the
    # first graph of that list is used as the partition graph.
    subg = dgl.load_graphs("outputs/part{}/graph.dgl".format(part_id))[0][0]

    # Keys of node_data have the form "<ntype>/<feature name>".
    node_data = {}
    for ntype in hg.ntypes:
        # Keep nodes that are flagged "inner_node" in this partition AND whose
        # type id equals the id of `ntype` in the full graph.
        # NOTE(review): presumably "inner_node" marks nodes owned by this
        # partition (as opposed to halo copies) -- confirm against DGL docs.
        local_node_idx = th.logical_and(
            subg.ndata["inner_node"].bool(),
            subg.ndata[dgl.NTYPE] == hg.get_ntype_id(ntype),
        )
        # "orig_id" values are used below as row indices into the full
        # graph's per-type feature tensors.
        local_nodes = subg.ndata["orig_id"][local_node_idx].numpy()
        for name in hg.nodes[ntype].data:
            node_data[ntype + "/" + name] = hg.nodes[ntype].data[name][
                local_nodes
            ]
    print("node features:", node_data.keys())
    dgl.data.utils.save_tensors(
        "outputs/" + metadata["part-{}".format(part_id)]["node_feats"],
        node_data,
    )

    # Keys of edge_data have the form "<etype>/<feature name>".
    edge_data = {}
    for etype in hg.etypes:
        # Select the "orig_id" of every edge in the partition graph whose
        # type id matches `etype`. Unlike the node case, no inner/owned
        # filter is applied here.
        local_edges = subg.edata["orig_id"][
            subg.edata[dgl.ETYPE] == hg.get_etype_id(etype)
        ]
        for name in hg.edges[etype].data:
            edge_data[etype + "/" + name] = hg.edges[etype].data[name][
                local_edges
            ]
    print("edge features:", edge_data.keys())
    dgl.data.utils.save_tensors(
        "outputs/" + metadata["part-{}".format(part_id)]["edge_feats"],
        edge_data,
    )