entity_utils.py 2.21 KB
Newer Older
Mufei Li's avatar
Mufei Li committed
1
2
3
import dgl
import torch as th

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
4
5
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset

Mufei Li's avatar
Mufei Li committed
6
7

def load_data(data_name, get_norm=False, inv_target=False):
    """Load an RDF entity-classification dataset as a homogeneous graph.

    Parameters
    ----------
    data_name : str
        ``"aifb"``, ``"mutag"`` or ``"bgs"`` select the matching dataset.
        Any other value falls back to ``AMDataset`` (kept for backward
        compatibility with existing callers).
    get_norm : bool, optional
        If True, store a per-edge ``"norm"`` feature computed with
        ``dgl.norm_by_dst`` on every canonical edge type and carry it over
        to the homogeneous graph.
    inv_target : bool, optional
        If True, additionally return a tensor mapping homogeneous node IDs
        of the target category to consecutive type-specific IDs.

    Returns
    -------
    tuple
        ``(g, num_rels, num_classes, labels, train_idx, test_idx,
        target_idx)``, with the inverse-target mapping appended as an
        eighth element when ``inv_target`` is True.

        * ``g`` -- homogeneous graph built from the heterograph, with node
          fields ``"ntype"`` and ``"type_id"``.
        * ``num_rels`` -- number of canonical edge types.
        * ``num_classes`` -- ``dataset.num_classes``.
        * ``labels`` -- labels popped from the target category's node data.
        * ``train_idx`` / ``test_idx`` -- 1-D index tensors of the nonzero
          entries of the train / test masks.
        * ``target_idx`` -- IDs in ``g`` of the nodes of the target
          category.
    """
    dataset_classes = {
        "aifb": AIFBDataset,
        "mutag": MUTAGDataset,
        "bgs": BGSDataset,
    }
    # Unrecognised names load AM, exactly as the original if/elif/else did.
    dataset = dataset_classes.get(data_name, AMDataset)()

    # Load hetero-graph
    hg = dataset[0]

    num_rels = len(hg.canonical_etypes)
    category = dataset.predict_category
    num_classes = dataset.num_classes

    # Pop labels and masks so they are not carried by the node data of the
    # target category any more.
    target_data = hg.nodes[category].data
    labels = target_data.pop("labels")
    train_mask = target_data.pop("train_mask")
    test_mask = target_data.pop("test_mask")

    # nonzero(as_tuple=False) returns one row per nonzero entry. Squeeze only
    # the coordinate axis: a bare squeeze() would also drop the row axis when
    # exactly one entry is set, yielding a 0-dim tensor instead of a 1-D index.
    # NOTE(review): assumes the masks are 1-D per-node masks -- confirm.
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1)

    if get_norm:
        # Calculate normalization weight for each edge,
        # 1. / d, d is the degree of the destination node
        for cetype in hg.canonical_etypes:
            hg.edges[cetype].data["norm"] = dgl.norm_by_dst(
                hg, cetype
            ).unsqueeze(1)
        edata = ["norm"]
    else:
        edata = None

    # get target category id
    category_id = hg.ntypes.index(category)

    g = dgl.to_homogeneous(hg, edata=edata)
    # Rename the fields as they can be changed by for example DataLoader
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)
    node_ids = th.arange(g.num_nodes())

    # find out the target node ids in g
    loc = g.ndata["ntype"] == category_id
    target_idx = node_ids[loc]

    if not inv_target:
        return g, num_rels, num_classes, labels, train_idx, test_idx, target_idx

    # Map global node IDs to type-specific node IDs. This is required for
    # looking up type-specific labels in a minibatch.
    # Only the target_idx positions are written; every other entry is left
    # uninitialized (th.empty) and must not be read.
    inv_target_map = th.empty((g.num_nodes(),), dtype=th.int64)
    inv_target_map[target_idx] = th.arange(
        0, target_idx.shape[0], dtype=inv_target_map.dtype
    )
    return (
        g,
        num_rels,
        num_classes,
        labels,
        train_idx,
        test_idx,
        target_idx,
        inv_target_map,
    )