dataloader.py 2.29 KB
Newer Older
Chen Sirui's avatar
Chen Sirui committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import copy

import numpy as np
import torch
import dgl
import networkx as nx
from torch.utils.data import Dataset, DataLoader


def build_dense_graph(n_particles):
    g = nx.complete_graph(n_particles)
    return dgl.from_networkx(g)


class MultiBodyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.zipfile = np.load(self.path)
        self.node_state = self.zipfile['data']
        self.node_label = self.zipfile['label']
        self.n_particles = self.zipfile['n_particles']

    def __len__(self):
        return self.node_state.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        node_state = self.node_state[idx, :, :]
        node_label = self.node_label[idx, :, :]
        return (node_state, node_label)


class MultiBodyTrainDataset(MultiBodyDataset):
    def __init__(self, data_path='./data/'):
        super(MultiBodyTrainDataset, self).__init__(
            data_path+'n_body_train.npz')
        self.stat_median = self.zipfile['median']
        self.stat_max = self.zipfile['max']
        self.stat_min = self.zipfile['min']


class MultiBodyValidDataset(MultiBodyDataset):
    def __init__(self, data_path='./data/'):
        super(MultiBodyValidDataset, self).__init__(
            data_path+'n_body_valid.npz')


class MultiBodyTestDataset(MultiBodyDataset):
    def __init__(self, data_path='./data/'):
        super(MultiBodyTestDataset, self).__init__(data_path+'n_body_test.npz')
        self.test_traj = self.zipfile['test_traj']
        self.first_frame = torch.from_numpy(self.zipfile['first_frame'])

# Construct fully connected graph


class MultiBodyGraphCollator:
    def __init__(self, n_particles):
        self.n_particles = n_particles
        self.graph = dgl.from_networkx(nx.complete_graph(self.n_particles))

    def __call__(self, batch):
        graph_list = []
        data_list = []
        label_list = []
        for frame in batch:
            graph_list.append(copy.deepcopy(self.graph))
            data_list.append(torch.from_numpy(frame[0]))
            label_list.append(torch.from_numpy(frame[1]))

        graph_batch = dgl.batch(graph_list)
        data_batch = torch.vstack(data_list)
        label_batch = torch.vstack(label_list)
        return graph_batch, data_batch, label_batch