"vscode:/vscode.git/clone" did not exist on "066e8a4ef0e9728cb8744944155c6da815c3d8a0"
dataloader.py 876 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch.utils.data
from torch.utils.data.dataloader import DataLoader
import dgl
import numpy as np


def collate_fn(batch):
    """
    Collate a list of ``(graph, label)`` samples into one mini-batch.

    Every node-feature tensor of every graph is cast to ``float32`` and the
    graphs are merged into a single batched graph with ``dgl.batch``. Labels
    are stacked with NumPy and converted to a ``torch.long`` tensor.

    No device transfer happens here. Graph features keep their current
    device, and the label tensor is built on the CPU.

    Note: the float cast is written back into each input graph's ``ndata``,
    so the graphs held by the dataset are modified as a side effect.

    Args:
        batch: sequence of ``(graph, label)`` pairs.

    Returns:
        tuple: ``(batched_graphs, batched_labels)``. Assuming every label has
        the same shape, ``batched_labels`` has shape
        ``(len(batch), *label_shape)`` and dtype ``torch.long``.
    """
    graphs, labels = map(list, zip(*batch))

    # Cast node features before batching, presumably so dgl.batch sees a
    # single dtype per feature key across all graphs (TODO confirm).
    for graph in graphs:
        # Snapshot the items so the view is not mutated while iterating it.
        for key, value in list(graph.ndata.items()):
            graph.ndata[key] = value.float()
    batched_graphs = dgl.batch(graphs)

    # torch.as_tensor replaces the legacy torch.LongTensor constructor; the
    # stacked array (e.g. int32 or float labels) is converted to int64.
    batched_labels = torch.as_tensor(np.array(labels), dtype=torch.long)

    return batched_graphs, batched_labels


class GraphDataLoader(DataLoader):
    """``DataLoader`` that batches samples with this module's ``collate_fn``.

    Args:
        dataset: dataset to load from, typically yielding ``(graph, label)``
            pairs.
        batch_size: number of samples per batch.
        shuffle: whether the underlying ``DataLoader`` reshuffles the data.
        **kwargs: forwarded unchanged to ``torch.utils.data.DataLoader``.
            A caller-supplied ``collate_fn`` overrides the module default
            instead of raising a duplicate-keyword ``TypeError``.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
        # Fall back to the graph-aware collate_fn only when the caller did not
        # provide one; previously passing it raised a duplicate-keyword error.
        if "collate_fn" not in kwargs:
            kwargs["collate_fn"] = collate_fn
        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,
                         **kwargs)