"""entity_utils.py: helper for loading the DGL RDF entity datasets
(AIFB, MUTAG, BGS, AM) and converting them to a homogeneous graph."""
import dgl
import torch as th

from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset

def load_data(data_name, get_norm=False, inv_target=False):
    """Load an RDF entity dataset and convert it to a homogeneous graph.

    Parameters
    ----------
    data_name : str
        ``'aifb'``, ``'mutag'`` or ``'bgs'`` select the corresponding
        dataset. Any other value falls back to ``AMDataset``.
    get_norm : bool, optional
        If True, store a ``'norm'`` edge feature (computed with
        ``dgl.norm_by_dst`` per canonical edge type) and carry it over to the
        homogeneous graph.
    inv_target : bool, optional
        If True, additionally return a tensor mapping node IDs of the
        homogeneous graph to positions within ``target_idx``.

    Returns
    -------
    tuple
        ``(g, num_rels, num_classes, labels, train_idx, test_idx,
        target_idx)``, with the inverse-target mapping appended as an eighth
        element when ``inv_target`` is True.

        * ``g`` -- homogeneous graph; its node data holds ``'ntype'`` and
          ``'type_id'`` (renamed from ``dgl.NTYPE`` / ``dgl.NID``).
        * ``num_rels`` -- number of canonical edge types of the hetero-graph.
        * ``labels`` -- labels of the nodes of the prediction category.
        * ``train_idx`` / ``test_idx`` -- indices of the nonzero entries of
          the train / test masks.
        * ``target_idx`` -- IDs in ``g`` of the nodes whose type is the
          prediction category.

    Notes
    -----
    ``labels``, ``train_mask`` and ``test_mask`` are popped from the
    hetero-graph's node data, so they are not carried into ``g``.
    """
    if data_name == 'aifb':
        dataset = AIFBDataset()
    elif data_name == 'mutag':
        dataset = MUTAGDataset()
    elif data_name == 'bgs':
        dataset = BGSDataset()
    else:
        # Kept for backward compatibility: every other name loads AM.
        dataset = AMDataset()

    # Load hetero-graph
    hg = dataset[0]

    num_rels = len(hg.canonical_etypes)
    category = dataset.predict_category
    num_classes = dataset.num_classes

    category_data = hg.nodes[category].data
    labels = category_data.pop('labels')
    train_mask = category_data.pop('train_mask')
    test_mask = category_data.pop('test_mask')

    # squeeze(1) rather than squeeze(): a bare squeeze() would collapse a
    # single selected index into a 0-d tensor, making the rank of the result
    # depend on the split size. Assumes the masks are 1-D, so nonzero()
    # returns a (k, 1) tensor.
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1)

    if get_norm:
        # Calculate normalization weight for each edge,
        # 1. / d, d is the degree of the destination node
        for cetype in hg.canonical_etypes:
            hg.edges[cetype].data['norm'] = dgl.norm_by_dst(hg, cetype).unsqueeze(1)
        edata = ['norm']
    else:
        edata = None

    # get target category id
    category_id = hg.ntypes.index(category)

    g = dgl.to_homogeneous(hg, edata=edata)
    # Rename the fields as they can be changed by for example DataLoader
    g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
    g.ndata['type_id'] = g.ndata.pop(dgl.NID)
    node_ids = th.arange(g.num_nodes())

    # find out the target node ids in g
    loc = (g.ndata['ntype'] == category_id)
    target_idx = node_ids[loc]

    if not inv_target:
        return g, num_rels, num_classes, labels, train_idx, test_idx, target_idx

    # Map global node IDs to type-specific node IDs. This is required for
    # looking up type-specific labels in a minibatch.
    # Uses a separate local name so the boolean ``inv_target`` flag is not
    # shadowed. Only the entries at ``target_idx`` are written; the remaining
    # entries of the th.empty buffer stay uninitialized.
    inv_target_map = th.empty((g.num_nodes(),), dtype=th.int64)
    inv_target_map[target_idx] = th.arange(0, target_idx.shape[0],
                                           dtype=inv_target_map.dtype)
    return (g, num_rels, num_classes, labels, train_idx, test_idx,
            target_idx, inv_target_map)