# dataloader.py: n-body datasets backed by .npz archives and a DGL batch collator.
import copy
import os

import dgl
import networkx as nx
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def build_dense_graph(n_particles):
    """Return a DGL graph connecting every pair of ``n_particles`` nodes.

    The topology is produced by ``networkx.complete_graph`` and then
    converted with ``dgl.from_networkx``.
    """
    complete_topology = nx.complete_graph(n_particles)
    dense_graph = dgl.from_networkx(complete_topology)
    return dense_graph


class MultiBodyDataset(Dataset):
    """Sample-indexed dataset backed by a single ``.npz`` archive.

    The archive must contain the arrays ``data``, ``label`` and
    ``n_particles``. Every array stored in the archive is kept in
    ``self.zipfile`` (keyed by its name) so subclasses can read additional
    split-specific entries from it.
    """

    def __init__(self, path):
        self.path = path
        # Read every array eagerly and close the archive immediately. Keeping
        # the lazily-loading NpzFile on the instance would hold its file handle
        # open for the dataset's whole lifetime and make the dataset
        # unpicklable (required, e.g., by DataLoader workers started with the
        # "spawn" method).
        with np.load(self.path) as archive:
            self.zipfile = {key: archive[key] for key in archive.files}

        self.node_state = self.zipfile["data"]
        self.node_label = self.zipfile["label"]
        self.n_particles = self.zipfile["n_particles"]

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.node_state.shape[0]

    def __getitem__(self, idx):
        """Return the ``(node_state, node_label)`` numpy arrays for sample ``idx``.

        ``idx`` may be an int or a torch tensor; tensors are converted with
        ``tolist()`` before indexing. The ``[idx, :, :]`` indexing requires
        ``data`` and ``label`` to be at least 3-D.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        node_state = self.node_state[idx, :, :]
        node_label = self.node_label[idx, :, :]
        return (node_state, node_label)


class MultiBodyTrainDataset(MultiBodyDataset):
    """Training split, read from ``n_body_train.npz`` inside ``data_path``.

    In addition to the base-class arrays, exposes the archive's ``median``,
    ``max`` and ``min`` entries as ``stat_median``, ``stat_max`` and
    ``stat_min``.
    """

    def __init__(self, data_path="./data/"):
        # os.path.join works whether or not data_path ends with a separator;
        # plain concatenation turned "./data" into "./datan_body_train.npz".
        super().__init__(os.path.join(data_path, "n_body_train.npz"))
        self.stat_median = self.zipfile["median"]
        self.stat_max = self.zipfile["max"]
        self.stat_min = self.zipfile["min"]


class MultiBodyValidDataset(MultiBodyDataset):
    """Validation split, read from ``n_body_valid.npz`` inside ``data_path``."""

    def __init__(self, data_path="./data/"):
        # os.path.join works whether or not data_path ends with a separator;
        # plain concatenation turned "./data" into "./datan_body_valid.npz".
        super().__init__(os.path.join(data_path, "n_body_valid.npz"))


class MultiBodyTestDataset(MultiBodyDataset):
    """Test split, read from ``n_body_test.npz`` inside ``data_path``.

    In addition to the base-class arrays, exposes the archive's ``test_traj``
    array and its ``first_frame`` array (converted to a torch tensor).
    """

    def __init__(self, data_path="./data/"):
        # os.path.join works whether or not data_path ends with a separator;
        # plain concatenation turned "./data" into "./datan_body_test.npz".
        super().__init__(os.path.join(data_path, "n_body_test.npz"))
        self.test_traj = self.zipfile["test_traj"]
        self.first_frame = torch.from_numpy(self.zipfile["first_frame"])

class MultiBodyGraphCollator:
    """Collate ``(node_state, node_label)`` samples into a batched DGL graph.

    Every sample uses the same fully connected topology over ``n_particles``
    nodes, so one template graph is built up front. A deep copy of it is made
    for each sample, so the template object itself is never passed to
    ``dgl.batch``.
    """

    def __init__(self, n_particles):
        self.n_particles = n_particles
        # Reuse the module helper rather than duplicating its construction.
        self.graph = build_dense_graph(self.n_particles)

    def __call__(self, batch):
        """Merge a list of samples whose first two items are numpy arrays.

        Returns ``(graph_batch, data_batch, label_batch)``:
        - ``graph_batch``: the per-sample graph copies combined with ``dgl.batch``;
        - ``data_batch`` / ``label_batch``: item 0 / item 1 of every sample,
          converted with ``torch.from_numpy`` and stacked with ``torch.vstack``.
        """
        graph_batch = dgl.batch([copy.deepcopy(self.graph) for _ in batch])
        data_batch = torch.vstack([torch.from_numpy(frame[0]) for frame in batch])
        label_batch = torch.vstack([torch.from_numpy(frame[1]) for frame in batch])
        return graph_batch, data_batch, label_batch