dataloading.py 2.62 KB
Newer Older
Chen Sirui's avatar
Chen Sirui committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import ssl
from six.moves import urllib
import torch
import numpy as np
import dgl
from torch.utils.data import Dataset, DataLoader


def download_file(dataset, path="./data"):
    """Download one file from the DGL dataset bucket into a local directory.

    Parameters
    ----------
    dataset : str
        Name of the remote file; it is also used as the local file name.
    path : str, optional
        Directory the file is written into. Defaults to ``./data``. The
        directory is created when it does not exist yet.

    Raises
    ------
    Any error raised by ``urlopen`` or by writing the local file is
    propagated. A partially written file is removed first, so a failed
    download is not mistaken for a cached copy on the next run.
    """
    print("Start Downloading data: {}".format(dataset))
    url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
        dataset)
    print("Start Downloading File....")
    if not os.path.isdir(path):
        os.makedirs(path)
    target = os.path.join(path, dataset)
    # NOTE(review): certificate verification is disabled here, so the
    # download is not protected against a man-in-the-middle. Kept as-is to
    # preserve existing behaviour; consider ssl.create_default_context().
    context = ssl._create_unverified_context()
    response = urllib.request.urlopen(url, context=context)
    try:
        with open(target, "wb") as handle:
            # Stream in 1 MiB chunks instead of holding the whole payload
            # in memory at once.
            for chunk in iter(lambda: response.read(1 << 20), b""):
                handle.write(chunk)
    except BaseException:
        # Callers decide whether to download by checking that the file
        # exists, so a truncated file must not be left behind.
        if os.path.exists(target):
            os.remove(target)
        raise
    finally:
        response.close()


class SnapShotDataset(Dataset):
    """Dataset serving paired entries of the ``x`` and ``y`` arrays of an
    ``.npz`` archive.

    Parameters
    ----------
    path : str
        Directory expected to contain the archive. It is created (including
        missing parent directories) when it does not exist.
    npz_file : str
        File name of the archive. When the file is missing under ``path``,
        ``download_file(npz_file)`` is called before loading.
    """

    def __init__(self, path, npz_file):
        archive_path = os.path.join(path, npz_file)
        if not os.path.exists(archive_path):
            if not os.path.exists(path):
                os.makedirs(path)
            # NOTE(review): only the file name is handed to download_file,
            # so the download location is decided there and not by `path`.
            # Confirm the two agree when `path` is not 'data'.
            download_file(npz_file)
        # np.load keeps the archive open until it is closed; both arrays are
        # read inside the context so the file handle is released afterwards.
        with np.load(archive_path) as zipfile:
            self.x = zipfile['x']
            self.y = zipfile['y']

    def __len__(self):
        """Return the number of entries along the first axis of ``x``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair selected by ``idx`` on the first axis.

        A tensor index is converted to plain Python values first.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.x[idx, ...], self.y[idx, ...]


def METR_LAGraphDataset():
    """Return the first graph stored in ``data/graph_la.bin``.

    When the file is not present locally, the ``data`` directory is created
    if needed and ``graph_la.bin`` is fetched with ``download_file`` first.
    """
    graph_file = 'data/graph_la.bin'
    cached = os.path.exists(graph_file)
    if not cached:
        # The cache directory has to exist before the file is fetched.
        if not os.path.exists('data'):
            os.mkdir('data')
        download_file('graph_la.bin')
    loaded = dgl.load_graphs(graph_file)
    graphs = loaded[0]
    return graphs[0]


class METR_LATrainDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/metr_la_train.npz``.

    Besides ``x`` and ``y`` it exposes ``mean`` and ``std``: the mean and
    standard deviation of ``x[..., 0]``.
    """

    def __init__(self):
        super().__init__('data', 'metr_la_train.npz')
        # Statistics cover index 0 of the last axis only.
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()


class METR_LATestDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/metr_la_test.npz``."""

    def __init__(self):
        super().__init__('data', 'metr_la_test.npz')


class METR_LAValidDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/metr_la_valid.npz``."""

    def __init__(self):
        super().__init__('data', 'metr_la_valid.npz')


def PEMS_BAYGraphDataset():
    """Return the first graph stored in ``data/graph_bay.bin``.

    When the file is not present locally, the ``data`` directory is created
    if needed and ``graph_bay.bin`` is fetched with ``download_file`` first.
    """
    graph_file = 'data/graph_bay.bin'
    cached = os.path.exists(graph_file)
    if not cached:
        # The cache directory has to exist before the file is fetched.
        if not os.path.exists('data'):
            os.mkdir('data')
        download_file('graph_bay.bin')
    loaded = dgl.load_graphs(graph_file)
    graphs = loaded[0]
    return graphs[0]


class PEMS_BAYTrainDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/pems_bay_train.npz``.

    Besides ``x`` and ``y`` it exposes ``mean`` and ``std``: the mean and
    standard deviation of ``x[..., 0]``.
    """

    def __init__(self):
        super().__init__('data', 'pems_bay_train.npz')
        # Statistics cover index 0 of the last axis only.
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()


class PEMS_BAYTestDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/pems_bay_test.npz``."""

    def __init__(self):
        super().__init__('data', 'pems_bay_test.npz')


class PEMS_BAYValidDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/pems_bay_valid.npz``."""

    def __init__(self):
        super().__init__('data', 'pems_bay_valid.npz')