data_loader.py 2.77 KB
Newer Older
KounianhuaDu's avatar
KounianhuaDu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from torch.utils.data import Dataset, DataLoader
import dgl
from dgl.data import PPIDataset
import collections

# Batch container yielded by the DataLoader: one merged DGL graph plus its node labels.
PPIBatch = collections.namedtuple('PPIBatch', ['graph', 'label'])


def batcher(device):
    """Build a ``collate_fn`` that merges a list of DGL graphs for a torch ``DataLoader``.

    Args:
        device: Target passed to ``Tensor.to`` for the node labels.

    Returns:
        A callable that takes a list of graphs and returns a ``PPIBatch``.
        Its ``graph`` is ``dgl.batch`` of the inputs. Its ``label`` is the
        merged graph's ``ndata['label']`` moved to ``device``. The graph
        itself is left on its original device.
    """
    def collate(samples):
        merged = dgl.batch(samples)
        node_labels = merged.ndata['label']
        return PPIBatch(graph=merged, label=node_labels.to(device))

    return collate

def _with_self_loop_relation(graph):
    """Return a heterograph copy of ``graph`` with an extra ``'self'`` edge type.

    The original edges keep the ``('_N', '_E', '_N')`` relation. A new
    ``('_N', 'self', '_N')`` relation links every node to itself. A fresh graph
    is built because a heterograph's schema cannot be changed once constructed.
    Only the node fields ``label``, ``feat`` and ``_ID`` and the edge field
    ``_ID`` are copied across.
    """
    nodes = graph.nodes()
    rebuilt = dgl.heterograph({
        ('_N', '_E', '_N'): graph.edges(),
        # Self-loops cover every node ID, so the node count is preserved.
        ('_N', 'self', '_N'): (nodes, nodes),
    })
    for key in ('label', 'feat', '_ID'):
        rebuilt.ndata[key] = graph.ndata[key]
    rebuilt.edges['_E'].data['_ID'] = graph.edata['_ID']
    return rebuilt


def _add_self_loop_relation_inplace(dataset):
    """Replace every graph in ``dataset`` with its self-loop-augmented copy."""
    # NOTE(review): writes to ``dataset.graphs`` directly, assuming
    # PPIDataset.__getitem__ serves items from that list; confirm for the
    # installed DGL version. An index loop is used because items are written
    # back by position.
    for idx in range(len(dataset)):
        dataset.graphs[idx] = _with_self_loop_relation(dataset[idx])


def load_PPI(batch_size=1, device='cpu'):
    """Load the PPI splits, add a self-loop relation to every graph, and build loaders.

    Args:
        batch_size: Number of graphs per mini-batch.
        device: Device that ``batcher`` moves the batched node labels to.

    Returns:
        The tuple ``(train_loader, valid_loader, test_loader, etypes, in_size,
        out_size)``:

        - ``etypes`` is the edge-type list of the rebuilt training graph.
        - ``in_size`` is ``ndata['feat'].shape[1]``.
        - ``out_size`` is ``ndata['label'].shape[1]``.

        Both sizes are read from the first training graph.
    """
    train_set = PPIDataset(mode='train')
    valid_set = PPIDataset(mode='valid')
    test_set = PPIDataset(mode='test')
    for dataset in (train_set, valid_set, test_set):
        _add_self_loop_relation_inplace(dataset)

    sample = train_set[0]
    etypes = sample.etypes
    in_size = sample.ndata['feat'].shape[1]
    out_size = sample.ndata['label'].shape[1]

    collate = batcher(device)
    # Only training data is shuffled. Evaluation loaders iterate in a fixed
    # order so batch composition, and hence per-batch metrics, is reproducible.
    train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=collate, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, collate_fn=collate, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=collate, shuffle=False)
    return train_loader, valid_loader, test_loader, etypes, in_size, out_size