Commit 74b1b814 authored by rusty1s's avatar rusty1s
Browse files

update

parent 2f25da6c
......@@ -9,15 +9,21 @@ for library in ['_relabel', '_async']:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
from .history import History # noqa
from .loader import SubgraphLoader # noqa
from .data import get_data # noqa
from .history import History # noqa
from .pool import AsyncIOPool # noqa
from .metis import metis, permute # noqa
from .utils import compute_acc # noqa
from .loader import SubgraphLoader, EvalSubgraphLoader # noqa
__all__ = [
'History',
'SubgraphLoader',
'get_data',
'History',
'AsyncIOPool',
'metis',
'permute',
'compute_acc',
'SubgraphLoader',
'EvalSubgraphLoader',
'__version__',
]
from typing import Tuple
import torch
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import Batch
from torch_geometric.data import Data, Batch
from torch_geometric.datasets import (Planetoid, WikiCS, Coauthor, Amazon,
GNNBenchmarkDataset, Yelp, Flickr,
Reddit2, PPI)
from ogb.nodeproppred import PygNodePropPredDataset
from .utils import index2mask, gen_masks
def get_planetoid(root: str, name: str) -> Tuple[Data, int, int]:
    """Load a Planetoid citation dataset (Cora/CiteSeer/PubMed).

    Args:
        root: Base directory for dataset storage/download.
        name: Dataset name as expected by :class:`Planetoid`.

    Returns:
        ``(data, num_features, num_classes)`` — the single graph with
        row-normalized features and a sparse adjacency tensor.
    """
    transform = T.Compose([T.NormalizeFeatures(), T.ToSparseTensor()])
    dataset = Planetoid(f'{root}/Planetoid', name, transform=transform)
    return dataset[0], dataset.num_features, dataset.num_classes
def get_wikics(root):
def get_wikics(root: str) -> Tuple[Data, int, int]:
dataset = WikiCS(f'{root}/WIKICS', transform=T.ToSparseTensor())
data = dataset[0]
data.adj_t = data.adj_t.to_symmetric()
......@@ -27,7 +26,7 @@ def get_wikics(root):
return data, dataset.num_features, dataset.num_classes
def get_coauthor(root, name):
def get_coauthor(root: str, name: str) -> Tuple[Data, int, int]:
dataset = Coauthor(f'{root}/Coauthor', name, transform=T.ToSparseTensor())
data = dataset[0]
torch.manual_seed(12345)
......@@ -36,7 +35,7 @@ def get_coauthor(root, name):
return data, dataset.num_features, dataset.num_classes
def get_amazon(root, name):
def get_amazon(root: str, name: str) -> Tuple[Data, int, int]:
dataset = Amazon(f'{root}/Amazon', name, transform=T.ToSparseTensor())
data = dataset[0]
torch.manual_seed(12345)
......@@ -45,118 +44,90 @@ def get_amazon(root, name):
return data, dataset.num_features, dataset.num_classes
def get_arxiv(root: str) -> Tuple[Data, int, int]:
    """Load the OGB ``ogbn-arxiv`` dataset with boolean split masks.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    dataset = PygNodePropPredDataset('ogbn-arxiv', f'{root}/OGB',
                                     pre_transform=T.ToSparseTensor())
    data = dataset[0]
    # The citation graph is directed; symmetrize it for undirected
    # message passing.
    data.adj_t = data.adj_t.to_symmetric()
    data.node_year = None  # Drop unused node attribute.
    data.y = data.y.view(-1)  # [N, 1] -> [N] for classification targets.
    split_idx = dataset.get_idx_split()
    # Convert OGB index-based splits into boolean node masks.
    data.train_mask = index2mask(split_idx['train'], data.num_nodes)
    data.val_mask = index2mask(split_idx['valid'], data.num_nodes)
    data.test_mask = index2mask(split_idx['test'], data.num_nodes)
    return data, dataset.num_features, dataset.num_classes
def get_products(root: str) -> Tuple[Data, int, int]:
    """Load the OGB ``ogbn-products`` dataset with boolean split masks.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    dataset = PygNodePropPredDataset('ogbn-products', f'{root}/OGB',
                                     pre_transform=T.ToSparseTensor())
    data = dataset[0]
    data.y = data.y.view(-1)  # [N, 1] -> [N] for classification targets.
    split_idx = dataset.get_idx_split()
    # Convert OGB index-based splits into boolean node masks.
    data.train_mask = index2mask(split_idx['train'], data.num_nodes)
    data.val_mask = index2mask(split_idx['valid'], data.num_nodes)
    data.test_mask = index2mask(split_idx['test'], data.num_nodes)
    return data, dataset.num_features, dataset.num_classes
def get_proteins(root: str) -> Tuple[Data, int, int]:
    """Load the OGB ``ogbn-proteins`` dataset with boolean split masks.

    Returns:
        ``(data, num_features, num_tasks)`` — the third element is the
        number of target columns (``data.y.size(-1)``), since targets
        here are float-valued per-column labels rather than a single
        class index.
    """
    dataset = PygNodePropPredDataset('ogbn-proteins', f'{root}/OGB',
                                     pre_transform=T.ToSparseTensor())
    data = dataset[0]
    data.node_species = None  # Drop unused node attribute.
    data.y = data.y.to(torch.float)  # Float targets (per-column labels).
    split_idx = dataset.get_idx_split()
    # Convert OGB index-based splits into boolean node masks.
    data.train_mask = index2mask(split_idx['train'], data.num_nodes)
    data.val_mask = index2mask(split_idx['valid'], data.num_nodes)
    data.test_mask = index2mask(split_idx['test'], data.num_nodes)
    return data, dataset.num_features, data.y.size(-1)
def get_yelp(root: str) -> Tuple[Data, int, int]:
    """Load the Yelp dataset with standardized node features.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    dataset = Yelp(f'{root}/YELP', pre_transform=T.ToSparseTensor())
    data = dataset[0]
    # Standardize features to zero mean / unit variance per dimension.
    data.x = (data.x - data.x.mean(dim=0)) / data.x.std(dim=0)
    return data, dataset.num_features, dataset.num_classes
def get_flickr(root: str) -> Tuple[Data, int, int]:
    """Load the Flickr dataset.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    dataset = Flickr(f'{root}/Flickr', pre_transform=T.ToSparseTensor())
    return dataset[0], dataset.num_features, dataset.num_classes
def get_reddit(root: str) -> Tuple[Data, int, int]:
    """Load the Reddit2 dataset with standardized node features.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    dataset = Reddit2(f'{root}/Reddit2', pre_transform=T.ToSparseTensor())
    data = dataset[0]
    # Standardize features to zero mean / unit variance per dimension.
    data.x = (data.x - data.x.mean(dim=0)) / data.x.std(dim=0)
    return data, dataset.num_features, dataset.num_classes
def get_ppi(root: str, split: str = 'train') -> Tuple[Data, int, int]:
    """Load one split of the PPI dataset, merged into a single graph.

    All graphs of the split are collapsed into one :class:`Batch`, and a
    ``'{split}_mask'`` covering every node is attached so downstream code
    can treat the result like the masked single-graph datasets.

    Args:
        root: Base directory for dataset storage/download.
        split: Which split to load (``'train'``, ``'val'`` or ``'test'``).

    Returns:
        ``(data, num_features, num_classes)``.
    """
    dataset = PPI(f'{root}/PPI', split=split, pre_transform=T.ToSparseTensor())
    data = Batch.from_data_list(dataset)
    # Drop batch bookkeeping so the merged batch reads as a single graph.
    data.batch = None
    data.ptr = None
    data[f'{split}_mask'] = torch.ones(data.num_nodes, dtype=torch.bool)
    return data, dataset.num_features, dataset.num_classes
def get_sbm(root: str, name: str) -> Tuple[Data, int, int]:
    """Load a GNNBenchmark SBM dataset (CLUSTER/PATTERN) as one graph.

    The training split's graphs are merged into a single :class:`Batch`.

    Returns:
        ``(data, num_features, num_classes)``.
    """
    dataset = GNNBenchmarkDataset(f'{root}/SBM', name, split='train',
                                  pre_transform=T.ToSparseTensor())
    data = Batch.from_data_list(dataset)
    # Drop batch bookkeeping so the merged batch reads as a single graph.
    data.batch = None
    data.ptr = None
    return data, dataset.num_features, dataset.num_classes
def get_data(root: str, name: str) -> Tuple[Data, int, int]:
    """Dispatch to the dataset loader matching *name* (case-insensitive).

    Args:
        root: Base directory for dataset storage/download.
        name: Dataset identifier, e.g. ``'Cora'``, ``'CoauthorCS'`` or
            ``'ogbn-arxiv'``.

    Returns:
        ``(data, num_features, num_classes)`` from the matching loader.

    Raises:
        NotImplementedError: If *name* matches no known dataset.
    """
    key = name.lower()
    if key in ['cora', 'citeseer', 'pubmed']:
        return get_planetoid(root, name)
    if key == 'wikics':
        return get_wikics(root)
    if key in ['coauthorcs', 'coauthorphysics']:
        # Strip the 'Coauthor' prefix; keep the caller's original casing.
        return get_coauthor(root, name[8:])
    if key in ['amazoncomputers', 'amazonphoto']:
        # Strip the 'Amazon' prefix; keep the caller's original casing.
        return get_amazon(root, name[6:])
    if key in ['ogbn-arxiv', 'arxiv']:
        return get_arxiv(root)
    if key in ['ogbn-products', 'products']:
        return get_products(root)
    # Fixed: this branch previously used `==` against the list, which is
    # always False for a string, making the proteins loader unreachable.
    if key in ['ogbn-proteins', 'proteins']:
        return get_proteins(root)
    if key == 'yelp':
        return get_yelp(root)
    if key == 'flickr':
        return get_flickr(root)
    if key == 'reddit':
        return get_reddit(root)
    if key == 'ppi':
        return get_ppi(root)
    if key in ['cluster', 'pattern']:
        return get_sbm(root, name)
    raise NotImplementedError(f"Unknown dataset '{name}'")
......@@ -7,8 +7,6 @@ from torch import Tensor
from torch_sparse import SparseTensor
from torch_geometric.data import Data
partition_fn = torch.ops.torch_sparse.partition
def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
log: bool = True) -> Tuple[Tensor, Tensor]:
......@@ -24,7 +22,8 @@ def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
else:
rowptr, col, _ = adj_t.csr()
cluster = partition_fn(rowptr, col, None, num_parts, recursive)
cluster = torch.ops.torch_sparse.partition(rowptr, col, None,
num_parts, recursive)
cluster, perm = cluster.sort()
ptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
......
from typing import Optional
from typing import Optional, Tuple
import torch
from torch import Tensor
......@@ -10,7 +10,8 @@ def index2mask(idx: Tensor, size: int) -> Tensor:
return mask
def compute_acc(logits: Tensor, y: Tensor, mask: Optional[Tensor] = None):
def compute_acc(logits: Tensor, y: Tensor,
mask: Optional[Tensor] = None) -> float:
if mask is not None:
logits, y = logits[mask], y[mask]
......@@ -29,7 +30,7 @@ def compute_acc(logits: Tensor, y: Tensor, mask: Optional[Tensor] = None):
def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
num_splits: int = 20):
num_splits: int = 20) -> Tuple[Tensor, Tensor, Tensor]:
num_classes = int(y.max()) + 1
train_mask = torch.zeros(y.size(0), num_splits, dtype=torch.bool)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment