data_loader.py 3.17 KB
Newer Older
1
2
import collections

KounianhuaDu's avatar
KounianhuaDu committed
3
4
5
import dgl
from dgl.data import PPIDataset

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
6
7
from torch.utils.data import DataLoader, Dataset

8
9
10
11
# Container produced by the DataLoader collate function: the batched DGL graph
# together with the node labels read from that graph's ``ndata["label"]``.
PPIBatch = collections.namedtuple("PPIBatch", ["graph", "label"])


KounianhuaDu's avatar
KounianhuaDu committed
12
13
14
def batcher(device):
    """Return a ``collate_fn`` that merges a list of graphs into a ``PPIBatch``.

    The returned callable batches the given graphs with ``dgl.batch`` and
    copies the batched ``"label"`` node data to ``device``. The batched graph
    itself is returned as produced by ``dgl.batch`` (it is not moved here).
    """

    def collate(graphs):
        merged = dgl.batch(graphs)
        labels = merged.ndata["label"].to(device)
        return PPIBatch(graph=merged, label=labels)

    return collate

21
22
23
24
25
26
27
28

# add a fresh "self-loop" edge type to the untyped PPI dataset and prepare train, val, test loaders
def load_PPI(batch_size=1, device="cpu"):
    """Load the PPI splits, add a ``"self"`` relation, and build data loaders.

    Each graph of the train, valid and test ``PPIDataset`` is replaced by a
    heterograph that keeps the original edges under the ``"_E"`` edge type and
    adds one self-loop per node under a new ``"self"`` edge type.

    Args:
        batch_size: Number of graphs per batch handed to each ``DataLoader``.
        device: Device the batched labels are moved to by the collate function.

    Returns:
        A tuple ``(train_loader, valid_loader, test_loader, etypes, in_size,
        out_size)`` where ``etypes`` are the edge types of the first rebuilt
        training graph, ``in_size`` is the second dimension of its ``"feat"``
        node data and ``out_size`` the second dimension of its ``"label"``
        node data.
    """
    train_set = PPIDataset(mode="train")
    valid_set = PPIDataset(mode="valid")
    test_set = PPIDataset(mode="test")
    splits = (train_set, valid_set, test_set)

    # The schema of a heterograph cannot be changed once constructed, so every
    # graph is rebuilt with the additional relation instead of edited in place.
    for split in splits:
        _add_self_loop_relation(split)

    first_graph = train_set[0]
    etypes = first_graph.etypes
    in_size = first_graph.ndata["feat"].shape[1]
    out_size = first_graph.ndata["label"].shape[1]

    # Prepare train, valid, and test dataloaders.
    # NOTE(review): the validation and test loaders are shuffled too; this is
    # kept exactly as before so existing callers see the same behaviour.
    train_loader, valid_loader, test_loader = (
        DataLoader(
            split,
            batch_size=batch_size,
            collate_fn=batcher(device),
            shuffle=True,
        )
        for split in splits
    )
    return train_loader, valid_loader, test_loader, etypes, in_size, out_size


def _add_self_loop_relation(dataset):
    """Replace each graph in ``dataset.graphs`` by one with a ``"self"`` etype.

    The rebuilt graph carries over the ``"label"``, ``"feat"`` and ``"_ID"``
    node data and the ``"_ID"`` edge data (attached to the ``"_E"`` edges).
    """
    # Index-based loop: graphs are read through ``dataset[i]`` and written
    # back into ``dataset.graphs[i]``.
    for i in range(len(dataset)):
        source = dataset[i]
        nodes = source.nodes()
        g = dgl.heterograph(
            {
                ("_N", "_E", "_N"): source.edges(),
                ("_N", "self", "_N"): (nodes, nodes),
            }
        )
        g.ndata["label"] = source.ndata["label"]
        g.ndata["feat"] = source.ndata["feat"]
        g.ndata["_ID"] = source.ndata["_ID"]
        g.edges["_E"].data["_ID"] = source.edata["_ID"]
        dataset.graphs[i] = g