data.py 5.13 KB
Newer Older
rusty1s's avatar
update  
rusty1s committed
1
2
from typing import Tuple

rusty1s's avatar
rusty1s committed
3
4
import torch
import torch_geometric.transforms as T
rusty1s's avatar
update  
rusty1s committed
5
from torch_geometric.data import Data, Batch
rusty1s's avatar
rusty1s committed
6
7
8
from torch_geometric.datasets import (Planetoid, WikiCS, Coauthor, Amazon,
                                      GNNBenchmarkDataset, Yelp, Flickr,
                                      Reddit2, PPI)
rusty1s's avatar
update  
rusty1s committed
9
from ogb.nodeproppred import PygNodePropPredDataset
rusty1s's avatar
rusty1s committed
10
11
12
13

from .utils import index2mask, gen_masks


rusty1s's avatar
update  
rusty1s committed
14
15
16
def get_planetoid(root: str, name: str) -> Tuple[Data, int, int]:
    """Load a Planetoid citation graph.

    Node features pass through ``T.NormalizeFeatures`` and the graph is
    converted with ``T.ToSparseTensor`` each time an item is accessed.

    Args:
        root: Base directory; the dataset is stored under ``{root}/Planetoid``.
        name: Planetoid dataset name (e.g. Cora, CiteSeer, PubMed),
            forwarded unchanged to ``Planetoid``.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    pipeline = [T.NormalizeFeatures(), T.ToSparseTensor()]
    planetoid = Planetoid(f'{root}/Planetoid', name,
                          transform=T.Compose(pipeline))
    graph = planetoid[0]
    return graph, planetoid.num_features, planetoid.num_classes


rusty1s's avatar
update  
rusty1s committed
20
def get_wikics(root: str) -> Tuple[Data, int, int]:
    """Load WikiCS with a symmetrized sparse adjacency.

    The dataset's ``stopping_mask`` is reused as ``val_mask`` and the
    ``stopping_mask`` attribute itself is cleared.

    Args:
        root: Base directory; the dataset is stored under ``{root}/WIKICS``.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    wikics = WikiCS(f'{root}/WIKICS', transform=T.ToSparseTensor())
    graph = wikics[0]
    graph.adj_t = graph.adj_t.to_symmetric()
    # Use the early-stopping split as the validation split, then drop it.
    graph.val_mask, graph.stopping_mask = graph.stopping_mask, None
    return graph, wikics.num_features, wikics.num_classes


rusty1s's avatar
update  
rusty1s committed
29
def get_coauthor(root: str, name: str) -> Tuple[Data, int, int]:
    """Load a Coauthor graph and attach randomly generated split masks.

    Args:
        root: Base directory; the dataset is stored under ``{root}/Coauthor``.
        name: Coauthor variant, forwarded unchanged to ``Coauthor``.

    Returns:
        ``(data, num_features, num_classes)`` with ``train_mask``,
        ``val_mask`` and ``test_mask`` set from ``gen_masks``.
    """
    coauthor = Coauthor(f'{root}/Coauthor', name,
                        transform=T.ToSparseTensor())
    graph = coauthor[0]
    # Fixed seed so the generated splits are reproducible. NOTE: this reseeds
    # the global torch RNG as a side effect.
    torch.manual_seed(12345)
    # NOTE(review): 20/30/20 presumably mean train-per-class / val-per-class /
    # number of splits -- confirm against utils.gen_masks.
    masks = gen_masks(graph.y, 20, 30, 20)
    graph.train_mask, graph.val_mask, graph.test_mask = masks
    return graph, coauthor.num_features, coauthor.num_classes


rusty1s's avatar
update  
rusty1s committed
38
def get_amazon(root: str, name: str) -> Tuple[Data, int, int]:
    """Load an Amazon co-purchase graph and attach randomly generated masks.

    Args:
        root: Base directory; the dataset is stored under ``{root}/Amazon``.
        name: Amazon variant, forwarded unchanged to ``Amazon``.

    Returns:
        ``(data, num_features, num_classes)`` with ``train_mask``,
        ``val_mask`` and ``test_mask`` set from ``gen_masks``.
    """
    amazon = Amazon(f'{root}/Amazon', name, transform=T.ToSparseTensor())
    graph = amazon[0]
    # Fixed seed so the generated splits are reproducible. NOTE: this reseeds
    # the global torch RNG as a side effect.
    torch.manual_seed(12345)
    # NOTE(review): 20/30/20 presumably mean train-per-class / val-per-class /
    # number of splits -- confirm against utils.gen_masks.
    masks = gen_masks(graph.y, 20, 30, 20)
    graph.train_mask, graph.val_mask, graph.test_mask = masks
    return graph, amazon.num_features, amazon.num_classes


rusty1s's avatar
update  
rusty1s committed
47
def get_arxiv(root: str) -> Tuple[Data, int, int]:
    """Load ogbn-arxiv as an undirected graph with boolean split masks.

    The adjacency is symmetrized, ``node_year`` is cleared, labels are
    flattened to 1-D, and the OGB index splits are converted to masks.

    Args:
        root: Base directory; the dataset is stored under ``{root}/OGB``.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    arxiv = PygNodePropPredDataset('ogbn-arxiv', f'{root}/OGB',
                                   pre_transform=T.ToSparseTensor())
    graph = arxiv[0]
    graph.adj_t = graph.adj_t.to_symmetric()
    graph.node_year = None
    graph.y = graph.y.view(-1)
    splits = arxiv.get_idx_split()
    # Map our mask attribute names onto OGB's split keys.
    for attr, key in (('train_mask', 'train'), ('val_mask', 'valid'),
                      ('test_mask', 'test')):
        setattr(graph, attr, index2mask(splits[key], graph.num_nodes))
    return graph, arxiv.num_features, arxiv.num_classes


rusty1s's avatar
update  
rusty1s committed
61
def get_products(root: str) -> Tuple[Data, int, int]:
    """Load ogbn-products with flattened labels and boolean split masks.

    Args:
        root: Base directory; the dataset is stored under ``{root}/OGB``.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    products = PygNodePropPredDataset('ogbn-products', f'{root}/OGB',
                                      pre_transform=T.ToSparseTensor())
    graph = products[0]
    graph.y = graph.y.view(-1)
    splits = products.get_idx_split()
    # Map our mask attribute names onto OGB's split keys.
    for attr, key in (('train_mask', 'train'), ('val_mask', 'valid'),
                      ('test_mask', 'test')):
        setattr(graph, attr, index2mask(splits[key], graph.num_nodes))
    return graph, products.num_features, products.num_classes


rusty1s's avatar
update  
rusty1s committed
73
def get_yelp(root: str) -> Tuple[Data, int, int]:
    """Load Yelp with per-column standardized node features.

    Each feature column is shifted by its mean and divided by its standard
    deviation.

    Args:
        root: Base directory; the dataset is stored under ``{root}/YELP``.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    dataset = Yelp(f'{root}/YELP', pre_transform=T.ToSparseTensor())
    data = dataset[0]
    mean = data.x.mean(dim=0)
    std = data.x.std(dim=0)
    # A zero-variance column would otherwise be divided by zero and become
    # NaN/inf. Dividing by 1 instead leaves it as all zeros after centering.
    std = std.masked_fill(std == 0, 1.0)
    data.x = (data.x - mean) / std
    return data, dataset.num_features, dataset.num_classes


rusty1s's avatar
update  
rusty1s committed
80
def get_flickr(root: str) -> Tuple[Data, int, int]:
    """Load the Flickr graph with a sparse adjacency.

    Args:
        root: Base directory; the dataset is stored under ``{root}/Flickr``.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    flickr = Flickr(f'{root}/Flickr', pre_transform=T.ToSparseTensor())
    graph = flickr[0]
    return graph, flickr.num_features, flickr.num_classes


rusty1s's avatar
update  
rusty1s committed
85
def get_reddit(root: str) -> Tuple[Data, int, int]:
    """Load Reddit2 with per-column standardized node features.

    Each feature column is shifted by its mean and divided by its standard
    deviation.

    Args:
        root: Base directory; the dataset is stored under ``{root}/Reddit2``.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    dataset = Reddit2(f'{root}/Reddit2', pre_transform=T.ToSparseTensor())
    data = dataset[0]
    mean = data.x.mean(dim=0)
    std = data.x.std(dim=0)
    # A zero-variance column would otherwise be divided by zero and become
    # NaN/inf. Dividing by 1 instead leaves it as all zeros after centering.
    std = std.masked_fill(std == 0, 1.0)
    data.x = (data.x - mean) / std
    return data, dataset.num_features, dataset.num_classes


rusty1s's avatar
update  
rusty1s committed
92
def get_ppi(root: str, split: str = 'train') -> Tuple[Data, int, int]:
    """Load one PPI split and merge all of its graphs into a single object.

    Args:
        root: Base directory; the dataset is stored under ``{root}/PPI``.
        split: Which PPI split to load; also names the all-True mask
            attribute that is added (``f'{split}_mask'``).

    Returns:
        ``(data, num_features, num_classes)``.
    """
    ppi = PPI(f'{root}/PPI', split=split, pre_transform=T.ToSparseTensor())
    merged = Batch.from_data_list(ppi)
    # Drop the batching bookkeeping so the result is used as one plain graph.
    merged.batch, merged.ptr = None, None
    # Every node of the merged graph belongs to the requested split.
    mask_name = f'{split}_mask'
    merged[mask_name] = torch.ones(merged.num_nodes, dtype=torch.bool)
    return merged, ppi.num_features, ppi.num_classes


rusty1s's avatar
update  
rusty1s committed
101
def get_sbm(root: str, name: str) -> Tuple[Data, int, int]:
    """Load the training split of an SBM benchmark merged into one object.

    Args:
        root: Base directory; the dataset is stored under ``{root}/SBM``.
        name: Benchmark name such as ``'PATTERN'`` or ``'CLUSTER'``,
            matched case-insensitively.

    Returns:
        ``(data, num_features, num_classes)`` where ``data`` combines all
        training graphs via ``Batch.from_data_list``.
    """
    # get_data() selects this loader case-insensitively, but
    # GNNBenchmarkDataset only accepts its canonical upper-case names, so
    # normalize here (a no-op for callers already passing 'CLUSTER' etc.).
    dataset = GNNBenchmarkDataset(f'{root}/SBM', name.upper(), split='train',
                                  pre_transform=T.ToSparseTensor())
    data = Batch.from_data_list(dataset)
    # Drop the batching bookkeeping so the result is used as one plain graph.
    data.batch = None
    data.ptr = None
    return data, dataset.num_features, dataset.num_classes


rusty1s's avatar
update  
rusty1s committed
110
def get_data(root: str, name: str) -> Tuple[Data, int, int]:
    """Dispatch a dataset name to its loader.

    Names are matched case-insensitively. For Coauthor/Amazon, the suffix
    after the ``coauthor``/``amazon`` prefix (with its original casing) is
    passed on as the variant name.

    Args:
        root: Base directory under which all datasets are stored.
        name: Dataset identifier, e.g. ``'cora'``, ``'CoauthorCS'``,
            ``'AmazonPhoto'``, ``'reddit'``, ``'ogbn-arxiv'``.

    Returns:
        ``(data, num_features, num_classes)`` from the selected loader.

    Raises:
        NotImplementedError: If ``name`` is not a supported dataset.
    """
    key = name.lower()
    if key in ['cora', 'citeseer', 'pubmed']:
        return get_planetoid(root, name)
    if key in ['coauthorcs', 'coauthorphysics']:
        return get_coauthor(root, name[len('coauthor'):])
    if key in ['amazoncomputers', 'amazonphoto']:
        return get_amazon(root, name[len('amazon'):])
    if key == 'wikics':
        return get_wikics(root)
    if key in ['cluster', 'pattern']:
        return get_sbm(root, name)
    if key == 'reddit':
        return get_reddit(root)
    if key == 'ppi':
        return get_ppi(root)
    if key == 'flickr':
        return get_flickr(root)
    if key == 'yelp':
        return get_yelp(root)
    if key in ['ogbn-arxiv', 'arxiv']:
        return get_arxiv(root)
    if key in ['ogbn-products', 'products']:
        return get_products(root)
    raise NotImplementedError(f"Unsupported dataset: '{name}'")