Commit 74b1b814 authored by rusty1s

update

parent 2f25da6c
...
@@ -9,15 +9,21 @@ for library in ['_relabel', '_async']:
     torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
         library, [osp.dirname(__file__)]).origin)
 
-from .history import History  # noqa
-from .loader import SubgraphLoader  # noqa
 from .data import get_data  # noqa
+from .history import History  # noqa
+from .pool import AsyncIOPool  # noqa
+from .metis import metis, permute  # noqa
 from .utils import compute_acc  # noqa
+from .loader import SubgraphLoader, EvalSubgraphLoader  # noqa
 
 __all__ = [
-    'History',
-    'SubgraphLoader',
     'get_data',
+    'History',
+    'AsyncIOPool',
+    'metis',
+    'permute',
     'compute_acc',
+    'SubgraphLoader',
+    'EvalSubgraphLoader',
     '__version__',
 ]
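For context, the loop at the top of this hunk locates each compiled extension ('_relabel', '_async') that ships next to the package's __init__.py and registers its custom operators with PyTorch before the Python modules are imported. A minimal sketch of the same pattern, assuming a hypothetical extension named '_myext':

import importlib.machinery
import os.path as osp

import torch

# Find the compiled shared library (e.g. '_myext.cpython-*.so') that sits
# in the same directory as this __init__.py; '_myext' is a placeholder name.
spec = importlib.machinery.PathFinder().find_spec(
    '_myext', [osp.dirname(__file__)])

# Loading it registers the library's custom ops under torch.ops.<namespace>.<op>.
torch.ops.load_library(spec.origin)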
+from typing import Tuple
 import torch
 import torch_geometric.transforms as T
-from ogb.nodeproppred import PygNodePropPredDataset
-from torch_geometric.data import Batch
+from torch_geometric.data import Data, Batch
 from torch_geometric.datasets import (Planetoid, WikiCS, Coauthor, Amazon,
                                       GNNBenchmarkDataset, Yelp, Flickr,
                                       Reddit2, PPI)
+from ogb.nodeproppred import PygNodePropPredDataset
 
 from .utils import index2mask, gen_masks
 
 
-def get_planetoid(root, name):
-    dataset = Planetoid(
-        f'{root}/Planetoid', name,
-        transform=T.Compose([T.NormalizeFeatures(),
-                             T.ToSparseTensor()]))
+def get_planetoid(root: str, name: str) -> Tuple[Data, int, int]:
+    transform = T.Compose([T.NormalizeFeatures(), T.ToSparseTensor()])
+    dataset = Planetoid(f'{root}/Planetoid', name, transform=transform)
     return dataset[0], dataset.num_features, dataset.num_classes
 
 
-def get_wikics(root):
+def get_wikics(root: str) -> Tuple[Data, int, int]:
     dataset = WikiCS(f'{root}/WIKICS', transform=T.ToSparseTensor())
     data = dataset[0]
     data.adj_t = data.adj_t.to_symmetric()
...
@@ -27,7 +26,7 @@ def get_wikics(root):
     return data, dataset.num_features, dataset.num_classes
 
 
-def get_coauthor(root, name):
+def get_coauthor(root: str, name: str) -> Tuple[Data, int, int]:
     dataset = Coauthor(f'{root}/Coauthor', name, transform=T.ToSparseTensor())
     data = dataset[0]
     torch.manual_seed(12345)
...
@@ -36,7 +35,7 @@ def get_coauthor(root, name):
     return data, dataset.num_features, dataset.num_classes
 
 
-def get_amazon(root, name):
+def get_amazon(root: str, name: str) -> Tuple[Data, int, int]:
     dataset = Amazon(f'{root}/Amazon', name, transform=T.ToSparseTensor())
     data = dataset[0]
     torch.manual_seed(12345)
...
@@ -45,118 +44,90 @@ def get_amazon(root, name):
     return data, dataset.num_features, dataset.num_classes
 
 
-def get_arxiv(root):
+def get_arxiv(root: str) -> Tuple[Data, int, int]:
     dataset = PygNodePropPredDataset('ogbn-arxiv', f'{root}/OGB',
                                      pre_transform=T.ToSparseTensor())
     data = dataset[0]
     data.adj_t = data.adj_t.to_symmetric()
     data.node_year = None
     data.y = data.y.view(-1)
     split_idx = dataset.get_idx_split()
     data.train_mask = index2mask(split_idx['train'], data.num_nodes)
     data.val_mask = index2mask(split_idx['valid'], data.num_nodes)
     data.test_mask = index2mask(split_idx['test'], data.num_nodes)
     return data, dataset.num_features, dataset.num_classes
 
 
-def get_products(root):
+def get_products(root: str) -> Tuple[Data, int, int]:
     dataset = PygNodePropPredDataset('ogbn-products', f'{root}/OGB',
                                      pre_transform=T.ToSparseTensor())
     data = dataset[0]
     data.y = data.y.view(-1)
     split_idx = dataset.get_idx_split()
     data.train_mask = index2mask(split_idx['train'], data.num_nodes)
     data.val_mask = index2mask(split_idx['valid'], data.num_nodes)
     data.test_mask = index2mask(split_idx['test'], data.num_nodes)
     return data, dataset.num_features, dataset.num_classes
 
 
-def get_proteins(root):
-    dataset = PygNodePropPredDataset('ogbn-proteins', f'{root}/OGB',
-                                     pre_transform=T.ToSparseTensor())
-    data = dataset[0]
-    data.node_species = None
-    data.y = data.y.to(torch.float)
-    split_idx = dataset.get_idx_split()
-    data.train_mask = index2mask(split_idx['train'], data.num_nodes)
-    data.val_mask = index2mask(split_idx['valid'], data.num_nodes)
-    data.test_mask = index2mask(split_idx['test'], data.num_nodes)
-    return data, dataset.num_features, data.y.size(-1)
-
-
-def get_yelp(root):
+def get_yelp(root: str) -> Tuple[Data, int, int]:
     dataset = Yelp(f'{root}/YELP', pre_transform=T.ToSparseTensor())
     data = dataset[0]
     data.x = (data.x - data.x.mean(dim=0)) / data.x.std(dim=0)
     return data, dataset.num_features, dataset.num_classes
 
 
-def get_flickr(root):
+def get_flickr(root: str) -> Tuple[Data, int, int]:
     dataset = Flickr(f'{root}/Flickr', pre_transform=T.ToSparseTensor())
     return dataset[0], dataset.num_features, dataset.num_classes
 
 
-def get_reddit(root):
+def get_reddit(root: str) -> Tuple[Data, int, int]:
     dataset = Reddit2(f'{root}/Reddit2', pre_transform=T.ToSparseTensor())
     data = dataset[0]
     data.x = (data.x - data.x.mean(dim=0)) / data.x.std(dim=0)
     return data, dataset.num_features, dataset.num_classes
 
 
-def get_ppi(root, split='train'):
+def get_ppi(root: str, split: str = 'train') -> Tuple[Data, int, int]:
     dataset = PPI(f'{root}/PPI', split=split, pre_transform=T.ToSparseTensor())
     data = Batch.from_data_list(dataset)
     data.batch = None
     data.ptr = None
     data[f'{split}_mask'] = torch.ones(data.num_nodes, dtype=torch.bool)
     return data, dataset.num_features, dataset.num_classes
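get_ppi above and get_sbm below share the same trick: Batch.from_data_list merges a multi-graph dataset into one large graph with a block-diagonal adjacency, and the bookkeeping attributes batch and ptr are then cleared so the result behaves like a single plain graph. A toy sketch of the effect (the two graphs and feature sizes are made up):

import torch
from torch_geometric.data import Batch, Data

# Two toy graphs with 2 and 3 nodes, standing in for PPI graphs.
g1 = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0, 1], [1, 0]]))
g2 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 2], [2, 0]]))

merged = Batch.from_data_list([g1, g2])
assert merged.num_nodes == 5   # features concatenated, edges offset per graph
merged.batch = None            # drop the graph-assignment vector
merged.ptr = None              # drop the cumulative node offsets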
 
 
-def get_sbm(root, name):
+def get_sbm(root: str, name: str) -> Tuple[Data, int, int]:
     dataset = GNNBenchmarkDataset(f'{root}/SBM', name, split='train',
                                   pre_transform=T.ToSparseTensor())
     data = Batch.from_data_list(dataset)
     data.batch = None
     data.ptr = None
     return data, dataset.num_features, dataset.num_classes
 
 
-def get_data(root, name):
+def get_data(root: str, name: str) -> Tuple[Data, int, int]:
     if name.lower() in ['cora', 'citeseer', 'pubmed']:
         return get_planetoid(root, name)
-    if name.lower() == 'wikics':
-        return get_wikics(root)
     if name.lower() in ['coauthorcs', 'coauthorphysics']:
         return get_coauthor(root, name[8:])
     if name.lower() in ['amazoncomputers', 'amazonphoto']:
         return get_amazon(root, name[6:])
-    if name.lower() in ['ogbn-arxiv', 'arxiv']:
-        return get_arxiv(root)
-    if name.lower() in ['ogbn-products', 'products']:
-        return get_products(root)
-    if name.lower() == ['ogbn-proteins', 'proteins']:
-        return get_proteins(root)
-    if name.lower() == 'yelp':
-        return get_yelp(root)
-    if name.lower() == 'flickr':
-        return get_flickr(root)
+    if name.lower() == 'wikics':
+        return get_wikics(root)
+    if name.lower() in ['cluster', 'pattern']:
+        return get_sbm(root, name)
     if name.lower() == 'reddit':
         return get_reddit(root)
     if name.lower() == 'ppi':
         return get_ppi(root)
-    if name.lower() in ['cluster', 'pattern']:
-        return get_sbm(root, name)
+    if name.lower() == 'flickr':
+        return get_flickr(root)
+    if name.lower() == 'yelp':
+        return get_yelp(root)
+    if name.lower() in ['ogbn-arxiv', 'arxiv']:
+        return get_arxiv(root)
+    if name.lower() in ['ogbn-products', 'products']:
+        return get_products(root)
     raise NotImplementedError
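Every helper above returns the same (data, num_features, num_classes) triple, so callers of get_data stay dataset-agnostic. A usage sketch (the root path is arbitrary):

data, num_features, num_classes = get_data('/tmp/datasets', 'cora')
# data.adj_t is a SparseTensor, since every loader applies T.ToSparseTensor().

# Matching is case-insensitive, and the OGB datasets accept both spellings:
data, num_features, num_classes = get_data('/tmp/datasets', 'ogbn-arxiv')
data, num_features, num_classes = get_data('/tmp/datasets', 'arxiv')

# Unknown names raise NotImplementedError.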
...
@@ -7,8 +7,6 @@ from torch import Tensor
 from torch_sparse import SparseTensor
 from torch_geometric.data import Data
 
-partition_fn = torch.ops.torch_sparse.partition
-
 
 def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
           log: bool = True) -> Tuple[Tensor, Tensor]:
...
@@ -24,7 +22,8 @@ def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
         perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
     else:
         rowptr, col, _ = adj_t.csr()
-        cluster = partition_fn(rowptr, col, None, num_parts, recursive)
+        cluster = torch.ops.torch_sparse.partition(rowptr, col, None,
+                                                   num_parts, recursive)
         cluster, perm = cluster.sort()
         ptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
...
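metis, whose signature appears in the hunk above, returns a node permutation perm that groups nodes cluster by cluster, plus a ptr vector of cluster boundaries (num_parts + 1 entries, via ind2ptr). A hedged usage sketch: the toy graph is made up, it requires a torch-sparse build that ships the METIS partition op, and the companion permute helper (exported from __init__.py above) is assumed to reorder a Data object according to perm:

import torch
from torch_sparse import SparseTensor

# A small random graph over 100 nodes (toy input).
row = torch.randint(0, 100, (500,))
col = torch.randint(0, 100, (500,))
adj_t = SparseTensor(row=row, col=col, sparse_sizes=(100, 100)).to_symmetric()

perm, ptr = metis(adj_t, num_parts=4)  # METIS node partitioning
assert ptr.numel() == 4 + 1            # cluster i is perm[ptr[i]:ptr[i+1]]
# data = permute(data, perm)           # assumed: reorder nodes by cluster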
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor
...
@@ -10,7 +10,8 @@ def index2mask(idx: Tensor, size: int) -> Tensor:
     return mask
 
 
-def compute_acc(logits: Tensor, y: Tensor, mask: Optional[Tensor] = None):
+def compute_acc(logits: Tensor, y: Tensor,
+                mask: Optional[Tensor] = None) -> float:
     if mask is not None:
         logits, y = logits[mask], y[mask]
...
@@ -29,7 +30,7 @@ def compute_acc(logits: Tensor, y: Tensor, mask: Optional[Tensor] = None):
 def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
-              num_splits: int = 20):
+              num_splits: int = 20) -> Tuple[Tensor, Tensor, Tensor]:
     num_classes = int(y.max()) + 1
     train_mask = torch.zeros(y.size(0), num_splits, dtype=torch.bool)
...
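The new annotations make the utils contracts explicit: index2mask turns an index vector into a boolean mask of a given length, compute_acc returns a plain float, and gen_masks stacks num_splits random splits column-wise (its train mask is allocated as [num_nodes, num_splits] in the last line above). The body of index2mask is collapsed in this diff, so the sketch below is an assumed implementation consistent with its signature:

import torch
from torch import Tensor

def index2mask(idx: Tensor, size: int) -> Tensor:
    # Assumed body: scatter True at the given positions.
    mask = torch.zeros(size, dtype=torch.bool, device=idx.device)
    mask[idx] = True
    return mask

assert index2mask(torch.tensor([0, 2]), 4).tolist() == [True, False, True, False]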