Commit 2f25da6c authored by rusty1s

initial commit

parent ac165af3
__pycache__/
_ext/
build/
dist/
.cache/
data/
.eggs/
*.egg-info/
outputs/
.coverage
*.pt
*.so
[style]
based_on_style = pep8
split_before_named_assigns = False
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include "cuda/async_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__async(void) { return NULL; }
#endif
void synchronize() {
#ifdef WITH_CUDA
synchronize_cuda();
#else
AT_ERROR("Not compiled with CUDA support");
#endif
}
void read_async(torch::Tensor src,
torch::optional<torch::Tensor> optional_offset,
torch::optional<torch::Tensor> optional_count,
torch::Tensor index, torch::Tensor dst, torch::Tensor buffer) {
#ifdef WITH_CUDA
read_async_cuda(src, optional_offset, optional_count, index, dst, buffer);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
}
void write_async(torch::Tensor src, torch::Tensor offset, torch::Tensor count,
torch::Tensor dst) {
#ifdef WITH_CUDA
write_async_cuda(src, offset, count, dst);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
}
static auto registry =
torch::RegisterOperators()
.op("torch_geometric_autoscale::synchronize", &synchronize)
.op("torch_geometric_autoscale::read_async", &read_async)
.op("torch_geometric_autoscale::write_async", &write_async);
#include "relabel_cpu.h"
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx, bool bipartite) {
AT_ASSERTM(!rowptr.is_cuda(), "Rowptr tensor must be a CPU tensor");
AT_ASSERTM(!col.is_cuda(), "Col tensor must be a CPU tensor");
if (optional_value.has_value()) {
auto value = optional_value.value();
AT_ASSERTM(!value.is_cuda(), "Value tensor must be a CPU tensor");
AT_ASSERTM(value.dim() == 1, "Value tensor must be one-dimensional");
}
AT_ASSERTM(!idx.is_cuda(), "Index tensor must be a CPU tensor");
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto idx_data = idx.data_ptr<int64_t>();
std::vector<int64_t> n_ids;
std::unordered_map<int64_t, int64_t> n_id_map;
std::unordered_map<int64_t, int64_t>::iterator it;
auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
out_rowptr_data[0] = 0;
int64_t v, w, c, row_start, row_end, offset = 0;
for (int64_t i = 0; i < idx.numel(); i++) {
v = idx_data[i];
n_id_map[v] = i;
offset += rowptr_data[v + 1] - rowptr_data[v];
out_rowptr_data[i + 1] = offset;
}
auto out_col = torch::empty(offset, col.options());
auto out_col_data = out_col.data_ptr<int64_t>();
torch::optional<torch::Tensor> out_value = torch::nullopt;
if (optional_value.has_value()) {
out_value = torch::empty(offset, optional_value.value().options());
AT_DISPATCH_ALL_TYPES(optional_value.value().scalar_type(), "relabel", [&] {
auto value_data = optional_value.value().data_ptr<scalar_t>();
auto out_value_data = out_value.value().data_ptr<scalar_t>();
offset = 0;
for (int64_t i = 0; i < idx.numel(); i++) {
v = idx_data[i];
row_start = rowptr_data[v], row_end = rowptr_data[v + 1];
for (int64_t j = row_start; j < row_end; j++) {
w = col_data[j];
it = n_id_map.find(w);
if (it == n_id_map.end()) {
c = idx.numel() + n_ids.size();
n_id_map[w] = c;
n_ids.push_back(w);
out_col_data[offset] = c;
} else {
out_col_data[offset] = it->second;
}
out_value_data[offset] = value_data[j];
offset++;
}
}
});
} else {
offset = 0;
for (int64_t i = 0; i < idx.numel(); i++) {
v = idx_data[i];
row_start = rowptr_data[v], row_end = rowptr_data[v + 1];
for (int64_t j = row_start; j < row_end; j++) {
w = col_data[j];
it = n_id_map.find(w);
if (it == n_id_map.end()) {
c = idx.numel() + n_ids.size();
n_id_map[w] = c;
n_ids.push_back(w);
out_col_data[offset] = c;
} else {
out_col_data[offset] = it->second;
}
offset++;
}
}
}
if (!bipartite)
out_rowptr = torch::cat(
{out_rowptr, torch::full({(int64_t)n_ids.size()}, out_col.numel(),
rowptr.options())});
idx = torch::cat({idx, torch::from_blob(n_ids.data(), {(int64_t)n_ids.size()},
idx.options())});
return std::make_tuple(out_rowptr, out_col, out_value, idx);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx, bool bipartite);
#include "async_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "../thread.h"
Thread &getThread() {
static Thread thread;
return thread;
}
void synchronize_cuda() { getThread().synchronize(); }
void read_async_cuda(torch::Tensor src,
torch::optional<torch::Tensor> optional_offset,
torch::optional<torch::Tensor> optional_count,
torch::Tensor index, torch::Tensor dst,
torch::Tensor buffer) {
AT_ASSERTM(!src.is_cuda(), "Source tensor must be a CPU tensor");
AT_ASSERTM(!index.is_cuda(), "Index tensor must be a CPU tensor");
AT_ASSERTM(dst.is_cuda(), "Target tensor must be a CUDA tensor");
AT_ASSERTM(!buffer.is_cuda(), "Buffer tensor must be a CPU tensor");
AT_ASSERTM(buffer.is_pinned(), "Buffer tensor must be pinned");
AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");
AT_ASSERTM(buffer.is_contiguous(), "Buffer tensor must be contiguous");
AT_ASSERTM(index.dim() == 1, "Index tensor must be one-dimensional");
int64_t numel = 0;
if (optional_offset.has_value()) {
AT_ASSERTM(src.is_pinned(), "Source tensor must be pinned");
auto offset = optional_offset.value();
AT_ASSERTM(!offset.is_cuda(), "Offset tensor must be a CPU tensor");
AT_ASSERTM(offset.is_contiguous(), "Offset tensor must be contiguous");
AT_ASSERTM(offset.dim() == 1, "Offset tensor must be one-dimensional");
AT_ASSERTM(optional_count.has_value(), "Count tensor is undefined");
auto count = optional_count.value();
AT_ASSERTM(!count.is_cuda(), "Count tensor must be a CPU tensor");
AT_ASSERTM(count.is_contiguous(), "Count tensor must be contiguous");
AT_ASSERTM(count.dim() == 1, "Count tensor must be one-dimensional");
AT_ASSERTM(offset.numel() == count.numel(), "Size mismatch");
numel = count.sum().data_ptr<int64_t>()[0];
}
AT_ASSERTM(numel + index.numel() <= buffer.size(0),
"Buffer tensor size too small");
AT_ASSERTM(numel + index.numel() <= dst.size(0),
"Target tensor size too small");
auto stream = at::cuda::getCurrentCUDAStream(src.get_device());
AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(src.get_device()),
"Asynchronous read requires a non-default CUDA stream");
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "read_async", [&] {
getThread().run([=] {
int64_t size = src.numel() / src.size(0);
auto src_data = src.data_ptr<scalar_t>();
auto dst_data = dst.data_ptr<scalar_t>();
if (optional_offset.has_value()) {
auto offset = optional_offset.value();
auto count = optional_count.value();
auto offset_data = offset.data_ptr<int64_t>();
auto count_data = count.data_ptr<int64_t>();
int64_t src_offset, dst_offset = 0, c;
for (int64_t i = 0; i < offset.numel(); i++) {
src_offset = offset_data[i], c = count_data[i];
AT_ASSERTM(src_offset + c <= src.size(0), "Invalid index");
AT_ASSERTM(dst_offset + c <= dst.size(0), "Invalid index");
// Copy from the pinned CPU source into the CUDA target (host-to-device).
cudaMemcpyAsync(
dst_data + (dst_offset * size), src_data + (src_offset * size),
c * size * sizeof(scalar_t), cudaMemcpyHostToDevice, stream);
dst_offset += c;
}
}
auto _buffer = buffer.narrow(0, 0, index.numel()); // convert to non-const
torch::index_select_out(_buffer, src, 0, index);
int64_t dim = src.numel() / src.size(0);
cudaMemcpyAsync(dst_data + numel * size, buffer.data_ptr<scalar_t>(),
index.numel() * dim * sizeof(scalar_t),
cudaMemcpyHostToDevice, stream);
});
});
}
void write_async_cuda(torch::Tensor src, torch::Tensor offset,
torch::Tensor count, torch::Tensor dst) {
AT_ASSERTM(src.is_cuda(), "Source tensor must be a CUDA tensor");
AT_ASSERTM(!offset.is_cuda(), "Offset tensor must be a CPU tensor");
AT_ASSERTM(!count.is_cuda(), "Count tensor must be a CPU tensor");
AT_ASSERTM(!dst.is_cuda(), "Target tensor must be a CPU tensor");
AT_ASSERTM(dst.is_pinned(), "Target tensor must be pinned");
AT_ASSERTM(src.is_contiguous(), "Source tensor must be contiguous");
AT_ASSERTM(offset.is_contiguous(), "Offset tensor must be contiguous");
AT_ASSERTM(count.is_contiguous(), "Count tensor must be contiguous");
AT_ASSERTM(dst.is_contiguous(), "Target tensor must be contiguous");
AT_ASSERTM(offset.dim() == 1, "Offset tensor must be one-dimensional");
AT_ASSERTM(count.dim() == 1, "Count tensor must be one-dimensional");
AT_ASSERTM(offset.numel() == count.numel(), "Size mismatch");
auto stream = at::cuda::getCurrentCUDAStream(src.get_device());
AT_ASSERTM(stream != at::cuda::getDefaultCUDAStream(src.get_device()),
"Asynchronous write requires a non-default CUDA stream");
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "write_async", [&] {
int64_t size = src.numel() / src.size(0);
auto src_data = src.data_ptr<scalar_t>();
auto offset_data = offset.data_ptr<int64_t>();
auto count_data = count.data_ptr<int64_t>();
auto dst_data = dst.data_ptr<scalar_t>();
int64_t src_offset = 0, dst_offset, c;
for (int64_t i = 0; i < offset.numel(); i++) {
dst_offset = offset_data[i], c = count_data[i];
AT_ASSERTM(src_offset + c <= src.size(0), "Invalid index");
AT_ASSERTM(dst_offset + c <= dst.size(0), "Invalid index");
cudaMemcpyAsync(
dst_data + (dst_offset * size), src_data + (src_offset * size),
c * size * sizeof(scalar_t), cudaMemcpyDeviceToHost, stream);
src_offset += c;
}
});
}
#pragma once
#include <torch/extension.h>
void synchronize_cuda();
void read_async_cuda(torch::Tensor src,
torch::optional<torch::Tensor> optional_offset,
torch::optional<torch::Tensor> optional_count,
torch::Tensor index, torch::Tensor dst,
torch::Tensor buffer);
void write_async_cuda(torch::Tensor src, torch::Tensor offset,
torch::Tensor count, torch::Tensor dst);
#include <Python.h>
#include <torch/script.h>
#include "cpu/relabel_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC PyInit__relabel(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx, bool bipartite) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return relabel_one_hop_cpu(rowptr, col, optional_value, idx, bipartite);
}
}
static auto registry = torch::RegisterOperators().op(
"torch_geometric_autoscale::relabel_one_hop", &relabel_one_hop);
#pragma once
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
// A simple C++11 Thread Pool implementation with `num_workers=1`.
// See: https://github.com/progschj/ThreadPool
class Thread {
public:
Thread();
~Thread();
template <class F> void run(F &&f);
void synchronize();
private:
bool stop;
std::mutex mutex;
std::thread worker;
std::condition_variable condition;
std::queue<std::future<void>> results;
std::queue<std::function<void()>> tasks;
};
inline Thread::Thread() : stop(false) {
worker = std::thread([this] {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->mutex);
this->condition.wait(
lock, [this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
});
}
inline Thread::~Thread() {
{
std::unique_lock<std::mutex> lock(mutex);
stop = true;
}
condition.notify_all();
worker.join();
}
template <class F> void Thread::run(F &&f) {
auto task = std::make_shared<std::packaged_task<void()>>(
std::bind(std::forward<F>(f)));
results.emplace(task->get_future());
{
std::unique_lock<std::mutex> lock(mutex);
tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
}
inline void Thread::synchronize() {
if (results.empty())
return;
results.front().get();
results.pop();
}
[metadata]
description-file = README.md
[aliases]
test = pytest
[tool:pytest]
addopts = --capture=no --cov
[flake8]
ignore=F811,W503,W504
import os
import sys
import glob
import os.path as osp
from setuptools import setup, find_packages
import torch
from torch.__config__ import parallel_info
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
def get_extensions():
Extension = CppExtension
define_macros = []
libraries = []
extra_compile_args = {'cxx': []}
extra_link_args = []
info = parallel_info()
if 'parallel backend: OpenMP' in info and 'OpenMP not found' not in info:
extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP']
if sys.platform == 'win32':
extra_compile_args['cxx'] += ['/openmp']
else:
extra_compile_args['cxx'] += ['-fopenmp']
else:
print('Compiling without OpenMP...')
if WITH_CUDA:
Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
extra_compile_args['nvcc'] = nvcc_flags
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
extensions = []
for main in main_files:
name = main.split(os.sep)[-1][:-4]
sources = [main]
path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
if osp.exists(path):
sources += [path]
path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
if WITH_CUDA and osp.exists(path):
sources += [path]
extension = Extension(
'torch_geometric_autoscale._' + name,
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
libraries=libraries,
)
extensions += [extension]
return extensions
__version__ = '0.0.0'
install_requires = []
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
setup(
name='torch_geometric_autoscale',
version='0.0.0',
description='PyGAS: Auto-Scaling in PyG',
python_requires='>=3.6',
install_requires=install_requires,
setup_requires=setup_requires,
tests_require=tests_require,
ext_modules=get_extensions(),
cmdclass={
'build_ext':
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
},
packages=find_packages(),
)
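# Typical build/test invocations (illustrative; inferred from the setup.cfg
# 'test = pytest' alias above, not documented in this commit):
#   pip install -e .        # compiles the C++/CUDA extensions under csrc/
#   python setup.py test    # runs the pytest suite via the alias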
import importlib
import os.path as osp
import torch
__version__ = '0.0.0'
for library in ['_relabel', '_async']:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
from .history import History # noqa
from .loader import SubgraphLoader # noqa
from .data import get_data # noqa
from .utils import compute_acc # noqa
__all__ = [
'History',
'SubgraphLoader',
'get_data',
'compute_acc',
'__version__',
]
import torch
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import Batch
from torch_geometric.datasets import (Planetoid, WikiCS, Coauthor, Amazon,
GNNBenchmarkDataset, Yelp, Flickr,
Reddit2, PPI)
from .utils import index2mask, gen_masks
def get_planetoid(root, name):
dataset = Planetoid(
f'{root}/Planetoid', name,
transform=T.Compose([T.NormalizeFeatures(),
T.ToSparseTensor()]))
return dataset[0], dataset.num_features, dataset.num_classes
def get_wikics(root):
dataset = WikiCS(f'{root}/WIKICS', transform=T.ToSparseTensor())
data = dataset[0]
data.adj_t = data.adj_t.to_symmetric()
data.val_mask = data.stopping_mask
data.stopping_mask = None
return data, dataset.num_features, dataset.num_classes
def get_coauthor(root, name):
dataset = Coauthor(f'{root}/Coauthor', name, transform=T.ToSparseTensor())
data = dataset[0]
torch.manual_seed(12345)
data.train_mask, data.val_mask, data.test_mask = gen_masks(
data.y, 20, 30, 20)
return data, dataset.num_features, dataset.num_classes
def get_amazon(root, name):
dataset = Amazon(f'{root}/Amazon', name, transform=T.ToSparseTensor())
data = dataset[0]
torch.manual_seed(12345)
data.train_mask, data.val_mask, data.test_mask = gen_masks(
data.y, 20, 30, 20)
return data, dataset.num_features, dataset.num_classes
def get_arxiv(root):
dataset = PygNodePropPredDataset('ogbn-arxiv', f'{root}/OGB',
pre_transform=T.ToSparseTensor())
data = dataset[0]
data.adj_t = data.adj_t.to_symmetric()
data.node_year = None
data.y = data.y.view(-1)
split_idx = dataset.get_idx_split()
data.train_mask = index2mask(split_idx['train'], data.num_nodes)
data.val_mask = index2mask(split_idx['valid'], data.num_nodes)
data.test_mask = index2mask(split_idx['test'], data.num_nodes)
return data, dataset.num_features, dataset.num_classes
def get_products(root):
dataset = PygNodePropPredDataset('ogbn-products', f'{root}/OGB',
pre_transform=T.ToSparseTensor())
data = dataset[0]
data.y = data.y.view(-1)
split_idx = dataset.get_idx_split()
data.train_mask = index2mask(split_idx['train'], data.num_nodes)
data.val_mask = index2mask(split_idx['valid'], data.num_nodes)
data.test_mask = index2mask(split_idx['test'], data.num_nodes)
return data, dataset.num_features, dataset.num_classes
def get_proteins(root):
dataset = PygNodePropPredDataset('ogbn-proteins', f'{root}/OGB',
pre_transform=T.ToSparseTensor())
data = dataset[0]
data.node_species = None
data.y = data.y.to(torch.float)
split_idx = dataset.get_idx_split()
data.train_mask = index2mask(split_idx['train'], data.num_nodes)
data.val_mask = index2mask(split_idx['valid'], data.num_nodes)
data.test_mask = index2mask(split_idx['test'], data.num_nodes)
return data, dataset.num_features, data.y.size(-1)
def get_yelp(root):
dataset = Yelp(f'{root}/YELP', pre_transform=T.ToSparseTensor())
data = dataset[0]
data.x = (data.x - data.x.mean(dim=0)) / data.x.std(dim=0)
return data, dataset.num_features, dataset.num_classes
def get_flickr(root):
dataset = Flickr(f'{root}/Flickr', pre_transform=T.ToSparseTensor())
return dataset[0], dataset.num_features, dataset.num_classes
def get_reddit(root):
dataset = Reddit2(f'{root}/Reddit2', pre_transform=T.ToSparseTensor())
data = dataset[0]
data.x = (data.x - data.x.mean(dim=0)) / data.x.std(dim=0)
return data, dataset.num_features, dataset.num_classes
def get_ppi(root, split='train'):
dataset = PPI(f'{root}/PPI', split=split, pre_transform=T.ToSparseTensor())
data = Batch.from_data_list(dataset)
data.batch = None
data.ptr = None
data[f'{split}_mask'] = torch.ones(data.num_nodes, dtype=torch.bool)
return data, dataset.num_features, dataset.num_classes
def get_sbm(root, name):
dataset = GNNBenchmarkDataset(f'{root}/SBM', name, split='train',
pre_transform=T.ToSparseTensor())
data = Batch.from_data_list(dataset)
data.batch = None
data.ptr = None
return data, dataset.num_features, dataset.num_classes
def get_data(root, name):
if name.lower() in ['cora', 'citeseer', 'pubmed']:
return get_planetoid(root, name)
if name.lower() == 'wikics':
return get_wikics(root)
if name.lower() in ['coauthorcs', 'coauthorphysics']:
return get_coauthor(root, name[8:])
if name.lower() in ['amazoncomputers', 'amazonphoto']:
return get_amazon(root, name[6:])
if name.lower() in ['ogbn-arxiv', 'arxiv']:
return get_arxiv(root)
if name.lower() in ['ogbn-products', 'products']:
return get_products(root)
if name.lower() in ['ogbn-proteins', 'proteins']:
return get_proteins(root)
if name.lower() == 'yelp':
return get_yelp(root)
if name.lower() == 'flickr':
return get_flickr(root)
if name.lower() == 'reddit':
return get_reddit(root)
if name.lower() == 'ppi':
return get_ppi(root)
if name.lower() in ['cluster', 'pattern']:
return get_sbm(root, name)
raise NotImplementedError
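# Example call (root path and dataset name are illustrative):
#   data, num_features, num_classes = get_data('/tmp/datasets', 'ogbn-arxiv')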
from typing import Optional
import torch
from torch import Tensor
class History(torch.nn.Module):
r"""A node embedding storage module with asynchronous I/O support between
devices."""
def __init__(self, num_embeddings: int, embedding_dim: int, device=None):
super(History, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
pin_memory = device is None or str(device) == 'cpu'
self.emb = torch.empty(num_embeddings, embedding_dim, device=device,
pin_memory=pin_memory)
self._device = torch.device('cpu')
self.reset_parameters()
def reset_parameters(self):
self.emb.fill_(0)
def _apply(self, fn):
self._device = fn(torch.zeros(1)).device
return self
@torch.no_grad()
def pull(self, index: Optional[Tensor] = None) -> Tensor:
out = self.emb
if index is not None:
assert index.device == self.emb.device
out = out.index_select(0, index)
return out.to(device=self._device)
@torch.no_grad()
def push(self, x, index: Optional[Tensor] = None,
offset: Optional[Tensor] = None, count: Optional[Tensor] = None):
if index is None and x.size(0) != self.num_embeddings:
raise ValueError
elif index is None and x.size(0) == self.num_embeddings:
self.emb.copy_(x)
elif index is not None and (offset is None or count is None):
assert index.device == self.emb.device
self.emb[index] = x.to(self.emb.device)
else:
x_o = 0
x = x.to(self.emb.device)
for o, c in zip(offset.tolist(), count.tolist()):
self.emb[o:o + c] = x[x_o:x_o + c]
x_o += c
def forward(self, *args, **kwargs):
""""""
raise NotImplementedError
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.num_embeddings}, '
f'{self.embedding_dim}, emb_device={self.emb.device}, '
f'device={self._device})')
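# Minimal usage sketch (sizes are illustrative; by default `emb` lives in
# pinned CPU memory so it can be transferred asynchronously):
#   hist = History(num_embeddings=100, embedding_dim=16)
#   h, idx = torch.randn(10, 16), torch.arange(10)
#   hist.push(h, idx)     # write rows 0..9 of the history
#   out = hist.pull(idx)  # read them back, moved to `hist._device`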
from typing import Optional, Union, Tuple, NamedTuple, List
import time
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from torch_geometric.data import Data
relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop
class SubData(NamedTuple):
data: Union[Data, SparseTensor]
batch_size: int
n_id: Tensor
offset: Optional[Tensor]
count: Optional[Tensor]
def to(self, *args, **kwargs):
return SubData(self.data.to(*args, **kwargs), self.batch_size,
self.n_id, self.offset, self.count)
class SubgraphLoader(DataLoader):
r"""A simple subgraph loader that, given a randomly sampled or
pre-partioned batch of nodes, returns the subgraph of this batch
(including its 1-hop neighbors)."""
def __init__(
self,
data: Union[Data, SparseTensor],
ptr: Optional[Tensor] = None,
batch_size: int = 1,
bipartite: bool = True,
log: bool = True,
**kwargs,
):
self.__data__ = None if isinstance(data, SparseTensor) else data
self.__adj_t__ = data if isinstance(data, SparseTensor) else data.adj_t
self.__N__ = self.__adj_t__.size(1)
self.__E__ = self.__adj_t__.nnz()
self.__ptr__ = ptr
self.__bipartite__ = bipartite
if ptr is not None:
n_id = torch.arange(self.__N__)
batches = n_id.split((ptr[1:] - ptr[:-1]).tolist())
batches = [(i, batches[i]) for i in range(len(batches))]
if batch_size > 1:
super(SubgraphLoader,
self).__init__(batches,
collate_fn=self.sample_partitions,
batch_size=batch_size, **kwargs)
else:
if log:
t = time.perf_counter()
print('Pre-processing subgraphs...', end=' ', flush=True)
data_list = [
data for data in DataLoader(
batches, collate_fn=self.sample_partitions,
batch_size=batch_size, **kwargs)
]
if log:
print(f'Done! [{time.perf_counter() - t:.2f}s]')
super(SubgraphLoader,
self).__init__(data_list, batch_size=1,
collate_fn=lambda x: x[0], **kwargs)
else:
super(SubgraphLoader,
self).__init__(range(self.__N__),
collate_fn=self.sample_nodes,
batch_size=batch_size, **kwargs)
def sample_partitions(self, batches: List[Tuple[int, Tensor]]) -> SubData:
ptr_ids, n_ids = zip(*batches)
n_id = torch.cat(n_ids, dim=0)
batch_size = n_id.numel()
ptr_id = torch.tensor(ptr_ids)
offset = self.__ptr__[ptr_id]
count = self.__ptr__[ptr_id.add_(1)].sub_(offset)
rowptr, col, value = self.__adj_t__.csr()
rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
self.__bipartite__)
adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
is_sorted=True)
if self.__data__ is None:
return SubData(adj_t, batch_size, n_id, offset, count)
data = self.__data__.__class__(adj_t=adj_t)
for key, item in self.__data__:
if isinstance(item, Tensor) and item.size(0) == self.__N__:
data[key] = item.index_select(0, n_id)
elif isinstance(item, SparseTensor):
pass
else:
data[key] = item
return SubData(data, batch_size, n_id, offset, count)
def sample_nodes(self, n_ids: List[int]) -> SubData:
n_id = torch.tensor(n_ids)
batch_size = n_id.numel()
rowptr, col, value = self.__adj_t__.csr()
rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
self.__bipartite__)
adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
is_sorted=True)
if self.__data__ is None:
return SubData(adj_t, batch_size, n_id, None, None)
data = self.__data__.__class__(adj_t=adj_t)
for key, item in self.__data__:
if isinstance(item, Tensor) and item.size(0) == self.__N__:
data[key] = item.index_select(0, n_id)
elif isinstance(item, SparseTensor):
pass
else:
data[key] = item
return SubData(data, batch_size, n_id, None, None)
def __repr__(self):
return f'{self.__class__.__name__}()'
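# Usage sketch (illustrative; assumes `data` is a torch_geometric `Data`
# object holding an `adj_t` SparseTensor, and `ptr` contains METIS partition
# boundaries as produced by `metis` below):
#   loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)
#   for sub_data, batch_size, n_id, offset, count in loader:
#       ...  # 1-hop subgraph of the sampled partitions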
class EvalSubgraphLoader(SubgraphLoader):
def __init__(
self,
data: Union[Data, SparseTensor],
ptr: Optional[Tensor] = None,
batch_size: int = 1,
bipartite: bool = True,
log: bool = True,
**kwargs,
):
num_nodes = ptr[-1]
ptr = ptr[::batch_size]
if int(ptr[-1]) != int(num_nodes):
ptr = torch.cat([ptr, num_nodes.unsqueeze(0)], dim=0)
super(EvalSubgraphLoader,
self).__init__(data, ptr, 1, bipartite, log, num_workers=0,
shuffle=False, **kwargs)
import time
import copy
from typing import Union, Tuple
import torch
from torch import Tensor
from torch_sparse import SparseTensor
from torch_geometric.data import Data
partition_fn = torch.ops.torch_sparse.partition
def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
log: bool = True) -> Tuple[Tensor, Tensor]:
if log:
t = time.perf_counter()
print(f'Computing METIS partitioning with {num_parts} parts...',
end=' ', flush=True)
num_nodes = adj_t.size(0)
if num_parts <= 1:
perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
else:
rowptr, col, _ = adj_t.csr()
cluster = partition_fn(rowptr, col, None, num_parts, recursive)
cluster, perm = cluster.sort()
ptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
if log:
print(f'Done! [{time.perf_counter() - t:.2f}s]')
return perm, ptr
def permute(data: Union[Data, SparseTensor], perm: Tensor,
log: bool = True) -> Union[Data, SparseTensor]:
if log:
t = time.perf_counter()
print('Permuting data...', end=' ', flush=True)
if isinstance(data, Data):
data = copy.copy(data)
for key, item in data:
if isinstance(item, Tensor) and item.size(0) == data.num_nodes:
data[key] = item[perm]
if isinstance(item, Tensor) and item.size(0) == data.num_edges:
raise NotImplementedError
if isinstance(item, SparseTensor):
data[key] = permute(item, perm, log=False)
else:
data = data.permute(perm)
if log:
print(f'Done! [{time.perf_counter() - t:.2f}s]')
return data
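# Partitioning workflow sketch (shapes are illustrative):
#   perm, ptr = metis(data.adj_t, num_parts=40)  # perm: [N] node permutation
#                                                # ptr: [41] part boundaries
#   data = permute(data, perm)  # reorder node tensors and adjacency accordingly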
from .base import ScalableGNN
from .gcn import GCN
from .sage import SAGE
from .gat import GAT
from .appnp import APPNP
from .gcn2 import GCN2
from .gin import GIN
from .transformer import Transformer
from .pna import PNA
from .pna_jk import PNA_JK
__all__ = [
'ScalableGNN',
'GCN',
'SAGE',
'GAT',
'APPNP',
'GCN2',
'GIN',
'Transformer',
'PNA',
'PNA_JK',
]
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Linear
from torch_sparse import SparseTensor
from .base import ScalableGNN
class APPNP(ScalableGNN):
def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
out_channels: int, num_layers: int, alpha: float,
dropout: float = 0.0, device=None, dtype=None):
super(APPNP, self).__init__(num_nodes, out_channels, num_layers,
device=device)
self.in_channels = in_channels
self.out_channels = out_channels
self.alpha = alpha
self.dropout = dropout
self.lins = ModuleList()
self.lins.append(Linear(in_channels, hidden_channels))
self.lins.append(Linear(hidden_channels, out_channels))
self.reg_modules = self.lins[:1]
self.nonreg_modules = self.lins[1:]
def reset_parameters(self):
super(APPNP, self).reset_parameters()
for lin in self.lins:
lin.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None) -> Tensor:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[0](x)
x = x.relu()
x = F.dropout(x, p=self.dropout, training=self.training)
x = x_0 = self.lins[1](x)
for history in self.histories:
x = (1 - self.alpha) * (adj_t @ x) + self.alpha * x_0
x = self.push_and_pull(history, x, batch_size, n_id)
x = (1 - self.alpha) * (adj_t @ x) + self.alpha * x_0
if batch_size is not None:
x = x[:batch_size]
return x
@torch.no_grad()
def mini_inference(self, x: Tensor, loader) -> Tensor:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[0](x)
x = x.relu()
x = F.dropout(x, p=self.dropout, training=self.training)
x = x_0 = self.lins[1](x)
for history in self.histories:
for info in loader:
info = info.to(self.device)
batch_size, n_id, adj_t, e_id = info
h = x[n_id]
h_0 = x_0[n_id]
h = (1 - self.alpha) * (adj_t @ h) + self.alpha * h_0
history.push(h[:batch_size], n_id[:batch_size])
x = history.pull()
out = x.new_empty(self.num_nodes, self.out_channels)
for info in loader:
info = info.to(self.device)
batch_size, n_id, adj_t, e_id = info
h = x[n_id]
h_0 = x_0[n_id]
h = (1 - self.alpha) * (adj_t @ h) + self.alpha * h_0
out[n_id[:batch_size]] = h
return out
from typing import Optional, Callable
import warnings
import torch
from torch import Tensor
from torch_sparse import SparseTensor
from torch_geometric_autoscale import History
from torch_geometric_autoscale.pool import AsyncIOPool
class ScalableGNN(torch.nn.Module):
def __init__(self, num_nodes: int, hidden_channels: int, num_layers: int,
pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(ScalableGNN, self).__init__()
self.num_nodes = num_nodes
self.hidden_channels = hidden_channels
self.num_layers = num_layers
self.pool_size = num_layers if pool_size is None else pool_size
self.buffer_size = buffer_size
self.histories = torch.nn.ModuleList([
History(num_nodes, hidden_channels, device)
for _ in range(num_layers - 1)
])
self.pool = None
self._async = False
self.__out__ = None
@property
def emb_device(self):
return self.histories[0].emb.device
@property
def device(self):
return self.histories[0]._device
@property
def _out(self):
if self.__out__ is None:
self.__out__ = torch.empty(self.num_nodes, self.out_channels,
pin_memory=True)
return self.__out__
def _apply(self, fn: Callable) -> None:
super(ScalableGNN, self)._apply(fn)
if (str(self.emb_device) == 'cpu' and str(self.device)[:4] == 'cuda'
and self.pool_size is not None
and self.buffer_size is not None):
self.pool = AsyncIOPool(self.pool_size, self.buffer_size,
self.histories[0].embedding_dim)
self.pool.to(self.device)
return self
def reset_parameters(self):
for history in self.histories:
history.reset_parameters()
def __call__(self, x: Optional[Tensor] = None,
adj_t: Optional[SparseTensor] = None,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None,
offset: Optional[Tensor] = None,
count: Optional[Tensor] = None, loader=None,
**kwargs) -> Tensor:
if loader is not None:
return self.mini_inference(loader)
self._async = (self.pool is not None and batch_size is not None
and n_id is not None and offset is not None
and count is not None)
if batch_size is not None and not self._async:
warnings.warn('Asynchronous I/O disabled, although history and '
'model sit on different devices.')
if self._async:
for hist in self.histories:
self.pool.async_pull(hist.emb, None, None, n_id[batch_size:])
out = self.forward(x=x, adj_t=adj_t, batch_size=batch_size, n_id=n_id,
offset=offset, count=count, **kwargs)
if self._async:
for hist in self.histories:
self.pool.synchronize_push()
self._async = False
return out
def push_and_pull(self, history, x: Tensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None,
offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
if n_id is None and x.size(0) != self.num_nodes:
return x # Do nothing...
if n_id is None and x.size(0) == self.num_nodes:
history.push(x)
return x
assert n_id is not None
if batch_size is None:
history.push(x, n_id)
return x
if not self._async:
# Synchronous path: push the in-batch embeddings and pull the histories
# of the out-of-batch (1-hop) nodes directly.
history.push(x[:batch_size], n_id[:batch_size], offset, count)
h = history.pull(n_id[batch_size:])
return torch.cat([x[:batch_size], h], dim=0)
# Asynchronous path: the pull was already scheduled in `__call__`; wait for
# it here and schedule the push of the in-batch embeddings.
out = self.pool.synchronize_pull()[:n_id.numel() - batch_size]
self.pool.async_push(x[:batch_size], offset, count, history.emb)
out = torch.cat([x[:batch_size], out], dim=0)
self.pool.free_pull()
return out
@torch.no_grad()
def mini_inference(self, loader) -> Tensor:
loader = [data + ({}, ) for data in loader]
for batch, batch_size, n_id, offset, count, state in loader:
x = batch.x.to(self.device)
adj_t = batch.adj_t.to(self.device)
out = self.forward_layer(0, x, adj_t, state)[:batch_size]
self.pool.async_push(out, offset, count, self.histories[0].emb)
self.pool.synchronize_push()
for i in range(1, len(self.histories)):
for _, batch_size, n_id, offset, count, _ in loader:
self.pool.async_pull(self.histories[i - 1].emb, offset, count,
n_id[batch_size:])
for batch, batch_size, n_id, offset, count, state in loader:
adj_t = batch.adj_t.to(self.device)
x = self.pool.synchronize_pull()[:n_id.numel()]
out = self.forward_layer(i, x, adj_t, state)[:batch_size]
self.pool.async_push(out, offset, count, self.histories[i].emb)
self.pool.free_pull()
self.pool.synchronize_push()
for _, batch_size, n_id, offset, count, _ in loader:
self.pool.async_pull(self.histories[-1].emb, offset, count,
n_id[batch_size:])
for batch, batch_size, n_id, offset, count, state in loader:
adj_t = batch.adj_t.to(self.device)
x = self.pool.synchronize_pull()[:n_id.numel()]
out = self.forward_layer(self.num_layers - 1, x, adj_t,
state)[:batch_size]
self.pool.async_push(out, offset, count, self._out)
self.pool.free_pull()
self.pool.synchronize_push()
return self._out
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear, ModuleList
from torch_sparse import SparseTensor
from torch_geometric.nn import GATConv
from .base import ScalableGNN
class GAT(ScalableGNN):
def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
hidden_heads: int, out_channels: int, out_heads: int,
num_layers: int, residual: bool = False, dropout: float = 0.0,
device=None, dtype=None):
super(GAT, self).__init__(num_nodes, hidden_channels * hidden_heads,
num_layers, device=device)
self.in_channels = in_channels
self.hidden_heads = hidden_heads
self.out_channels = out_channels
self.out_heads = out_heads
self.residual = residual
self.dropout = dropout
self.convs = ModuleList()
for i in range(num_layers - 1):
in_dim = in_channels if i == 0 else hidden_channels * hidden_heads
conv = GATConv(in_dim, hidden_channels, hidden_heads, concat=True,
dropout=dropout, add_self_loops=False)
self.convs.append(conv)
conv = GATConv(hidden_channels * hidden_heads, out_channels, out_heads,
concat=False, dropout=dropout, add_self_loops=False)
self.convs.append(conv)
self.lins = ModuleList()
if residual:
self.lins.append(
Linear(in_channels, hidden_channels * hidden_heads))
self.lins.append(
Linear(hidden_channels * hidden_heads, out_channels))
self.reg_modules = ModuleList([self.convs, self.lins])
self.nonreg_modules = ModuleList()
def reset_parameters(self):
super(GAT, self).reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for lin in self.lins:
lin.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None) -> Tensor:
for conv, history in zip(self.convs[:-1], self.histories):
h = F.dropout(x, p=self.dropout, training=self.training)
h = conv(h, adj_t)
if self.residual:
x = F.dropout(x, p=self.dropout, training=self.training)
h = h + x if h.size(-1) == x.size(-1) else h + self.lins[0](x)
x = F.elu(h)
x = self.push_and_pull(history, x, batch_size, n_id)
h = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[-1](h, adj_t)
if self.residual:
x = F.dropout(x, p=self.dropout, training=self.training)
h = h + self.lins[1](x)
if batch_size is not None:
h = h[:batch_size]
return h
@torch.no_grad()
def mini_inference(self, x: Tensor, loader) -> Tensor:
for conv, history in zip(self.convs[:-1], self.histories):
for info in loader:
info = info.to(self.device)
batch_size, n_id, adj_t, e_id = info
r = x[n_id]
h = conv(r, adj_t)
if self.residual:
if h.size(-1) == r.size(-1):
h = h + r
else:
h = h + self.lins[0](r)
h = F.elu(h)
history.push(h[:batch_size], n_id[:batch_size])
x = history.pull()
out = x.new_empty(self.num_nodes, self.out_channels)
for info in loader:
info = info.to(self.device)
batch_size, n_id, adj_t, e_id = info
r = x[n_id]
h = self.convs[-1](r, adj_t)[:batch_size]
if self.residual:
h = h + self.lins[1](r)
out[n_id[:batch_size]] = h
return out