# dataloading.py
import os
import shutil
import ssl

import numpy as np
import torch
from six.moves import urllib
from torch.utils.data import DataLoader, Dataset

import dgl


def download_file(dataset, path="./data"):
    """Download ``dataset`` from the DGL S3 bucket and store it under ``path``.

    The payload is streamed to ``<path>/<dataset>.part`` and only renamed to
    ``<path>/<dataset>`` once the transfer has finished, so an interrupted
    download never leaves a truncated file at the final location.

    Args:
        dataset: File name of the remote object; it is also used as the
            local file name.
        path: Directory to store the file in. It is created when missing.
            Defaults to ``"./data"``, the location this function has always
            written to.

    Raises:
        Whatever ``urlopen`` raises on network/HTTP failure, and ``OSError``
        on file-system failure.
    """
    print("Start Downloading data: {}".format(dataset))
    url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
        dataset
    )
    print("Start Downloading File....")

    if not os.path.isdir(path):
        os.makedirs(path)
    target = os.path.join(path, dataset)
    partial = target + ".part"

    # NOTE(review): certificate verification is deliberately disabled here,
    # which means the server's identity is not checked. Kept for backward
    # compatibility; switch to ssl.create_default_context() if the
    # environment has a working CA bundle.
    context = ssl._create_unverified_context()
    response = urllib.request.urlopen(url, context=context)
    try:
        with open(partial, "wb") as handle:
            # Stream in chunks instead of holding the whole file in memory.
            shutil.copyfileobj(response, handle)
    finally:
        # Always release the connection, even when the copy fails.
        response.close()

    # Publish the finished file in one step; a leftover ".part" from a failed
    # attempt is simply overwritten by the next call.
    os.replace(partial, target)


class SnapShotDataset(Dataset):
    """Map-style dataset over the ``x`` and ``y`` arrays of an ``.npz`` file.

    Args:
        path: Directory expected to contain ``npz_file``.
        npz_file: Name of the ``.npz`` archive; it must provide the keys
            ``"x"`` and ``"y"``.

    If ``<path>/<npz_file>`` does not exist, ``path`` is created when missing
    and ``download_file(npz_file)`` is called before loading.
    """

    def __init__(self, path, npz_file):
        archive_path = os.path.join(path, npz_file)
        if not os.path.exists(archive_path):
            if not os.path.exists(path):
                # makedirs also copes with nested, not-yet-existing parents.
                os.makedirs(path)
            # NOTE(review): only the file name is handed to download_file, so
            # the file lands in that function's default directory, which is
            # not necessarily ``path`` -- confirm for non-default paths.
            download_file(npz_file)

        # Load inside a ``with`` block so the archive's file handle is closed
        # as soon as both arrays have been read into memory.
        with np.load(archive_path) as archive:
            self.x = archive["x"]
            self.y = archive["y"]

    def __len__(self):
        """Return the length of ``x`` along its first axis."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair selected by ``idx`` on the first axis.

        Tensor indices are converted to plain Python values first.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.x[idx, ...], self.y[idx, ...]


def METR_LAGraphDataset():
    """Return the first graph stored in ``data/graph_la.bin``.

    When the file is absent, the ``data`` directory is created if needed and
    ``graph_la.bin`` is fetched with ``download_file`` before loading.
    """
    graph_path = "data/graph_la.bin"
    needs_download = not os.path.exists(graph_path)

    if needs_download and not os.path.exists("data"):
        os.mkdir("data")
    if needs_download:
        download_file("graph_la.bin")

    # load_graphs yields (graphs, labels); only the graph list is used.
    graph_list, _labels = dgl.load_graphs(graph_path)
    return graph_list[0]


class METR_LATrainDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/metr_la_train.npz``.

    Besides ``x`` and ``y`` it exposes ``mean`` and ``std``, computed over
    index 0 of the last axis of ``x`` (presumably for normalisation --
    confirm against the training code).
    """

    def __init__(self):
        super().__init__(path="data", npz_file="metr_la_train.npz")
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()


class METR_LATestDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/metr_la_test.npz``."""

    def __init__(self):
        super().__init__(path="data", npz_file="metr_la_test.npz")


class METR_LAValidDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/metr_la_valid.npz``."""

    def __init__(self):
        super().__init__(path="data", npz_file="metr_la_valid.npz")


def PEMS_BAYGraphDataset():
    """Return the first graph stored in ``data/graph_bay.bin``.

    When the file is absent, the ``data`` directory is created if needed and
    ``graph_bay.bin`` is fetched with ``download_file`` before loading.
    """
    graph_path = "data/graph_bay.bin"
    needs_download = not os.path.exists(graph_path)

    if needs_download and not os.path.exists("data"):
        os.mkdir("data")
    if needs_download:
        download_file("graph_bay.bin")

    # load_graphs yields (graphs, labels); only the graph list is used.
    graph_list, _labels = dgl.load_graphs(graph_path)
    return graph_list[0]


class PEMS_BAYTrainDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/pems_bay_train.npz``.

    Besides ``x`` and ``y`` it exposes ``mean`` and ``std``, computed over
    index 0 of the last axis of ``x`` (presumably for normalisation --
    confirm against the training code).
    """

    def __init__(self):
        super().__init__(path="data", npz_file="pems_bay_train.npz")
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()


class PEMS_BAYTestDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/pems_bay_test.npz``."""

    def __init__(self):
        super().__init__(path="data", npz_file="pems_bay_test.npz")


class PEMS_BAYValidDataset(SnapShotDataset):
    """Snapshot dataset backed by ``data/pems_bay_valid.npz``."""

    def __init__(self):
        super().__init__(path="data", npz_file="pems_bay_valid.npz")