Commit 2f25da6c authored by rusty1s

initial commit

parent ac165af3
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, BatchNorm1d
from torch_sparse import SparseTensor
from torch_geometric.nn import GCNConv
from scaling_gnns.models.base2 import ScalableGNN
class GCN(ScalableGNN):
def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
out_channels: int, num_layers: int, dropout: float = 0.0,
drop_input: bool = True, batch_norm: bool = False,
residual: bool = False, pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(GCN, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
self.dropout = dropout
self.drop_input = drop_input
self.batch_norm = batch_norm
self.residual = residual
self.convs = ModuleList()
for i in range(num_layers):
in_dim = in_channels if i == 0 else hidden_channels
out_dim = out_channels if i == num_layers - 1 else hidden_channels
conv = GCNConv(in_dim, out_dim, normalize=False)
self.convs.append(conv)
self.bns = ModuleList()
for i in range(num_layers - 1):
bn = BatchNorm1d(hidden_channels)
self.bns.append(bn)
@property
def reg_modules(self):
return ModuleList(list(self.convs[:-1]) + list(self.bns))
@property
def nonreg_modules(self):
return self.convs[-1:]
def reset_parameters(self):
super(GCN, self).reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
for conv, bn, hist in zip(self.convs[:-1], self.bns, self.histories):
h = conv(x, adj_t)
if self.batch_norm:
h = bn(h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
x = h.relu_()
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, adj_t)
return x
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
if layer == 0:
    if self.drop_input:
        x = F.dropout(x, p=self.dropout, training=self.training)
else:
    x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[layer](x, adj_t)
if layer < self.num_layers - 1:
if self.batch_norm:
h = self.bns[layer](h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
h = h.relu_()
return h
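# --- Hedged usage sketch (illustrative only, not part of this commit). ---
# Assumes the ScalableGNN base in scaling_gnns.models.base2 accepts the
# constructor arguments used above and that calling the model without
# batch_size/n_id performs a plain full-batch pass (push_and_pull then acts
# as a no-op). GCNConv is constructed with normalize=False, so `adj_t` is
# expected to be symmetrically normalized beforehand.
def _example_gcn_usage(num_nodes: int = 1000, in_channels: int = 128):
    model = GCN(num_nodes=num_nodes, in_channels=in_channels,
                hidden_channels=64, out_channels=40, num_layers=3,
                dropout=0.5, batch_norm=True)
    x = torch.randn(num_nodes, in_channels)
    row = torch.randint(num_nodes, (5000, ))
    col = torch.randint(num_nodes, (5000, ))
    adj_t = SparseTensor(row=row, col=col,
                         sparse_sizes=(num_nodes, num_nodes))
    return model(x, adj_t)  # [num_nodes, out_channels] logits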
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Linear, BatchNorm1d
from torch_sparse import SparseTensor
from torch_geometric.nn import GCN2Conv
from scaling_gnns.models.base2 import ScalableGNN
class GCN2(ScalableGNN):
def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
out_channels: int, num_layers: int, alpha: float,
theta: float, shared_weights: bool = True,
dropout: float = 0.0, drop_input: bool = True,
batch_norm: bool = False, residual: bool = False,
pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(GCN2, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
self.dropout = dropout
self.drop_input = drop_input
self.batch_norm = batch_norm
self.residual = residual
self.lins = ModuleList()
self.lins.append(Linear(in_channels, hidden_channels))
self.lins.append(Linear(hidden_channels, out_channels))
self.convs = ModuleList()
for i in range(num_layers):
conv = GCN2Conv(hidden_channels, alpha=alpha, theta=theta,
layer=i + 1, shared_weights=shared_weights,
normalize=False)
self.convs.append(conv)
self.bns = ModuleList()
for i in range(num_layers):
bn = BatchNorm1d(hidden_channels)
self.bns.append(bn)
@property
def reg_modules(self):
return ModuleList(list(self.convs) + list(self.bns))
@property
def nonreg_modules(self):
return self.lins
def reset_parameters(self):
super(GCN2, self).reset_parameters()
for lin in self.lins:
lin.reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
x = x_0 = self.lins[0](x).relu_()
x = F.dropout(x, p=self.dropout, training=self.training)
for conv, bn, hist in zip(self.convs[:-1], self.bns[:-1],
self.histories):
h = conv(x, x_0, adj_t)
if self.batch_norm:
h = bn(h)
if self.residual:
h += x[:h.size(0)]
x = h.relu_()
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[-1](x, x_0, adj_t)
if self.batch_norm:
h = self.bns[-1](h)
if self.residual:
h += x[:h.size(0)]
x = h.relu_()
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[1](x)
return x
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
if layer == 0:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
x = x_0 = self.lins[0](x).relu_()
state['x_0'] = x_0[:adj_t.size(0)]
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[layer](x, state['x_0'], adj_t)
if self.batch_norm:
h = self.bns[layer](h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
x = h.relu_()
if layer == self.num_layers - 1:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[1](x)
return x
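# --- Hedged sketch of layer-wise inference (illustrative only). ---
# forward_layer() is meant to be driven one layer at a time; the `state` dict
# carries the layer-0 representation x_0 to the deeper GCN2Conv calls. A
# minimal full-batch driver (real layer-wise inference would stream
# mini-batches through each layer instead):
@torch.no_grad()
def _example_gcn2_layerwise(model: GCN2, x: Tensor,
                            adj_t: SparseTensor) -> Tensor:
    state = {}
    for layer in range(model.num_layers):
        x = model.forward_layer(layer, x, adj_t, state)
    return x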
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Identity
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_sparse import SparseTensor
from torch_geometric.nn import GINConv
from torch_geometric.nn.inits import reset
from .base import HistoryGNN
class GIN(HistoryGNN):
def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
out_channels: int, num_layers: int, residual: bool = False,
dropout: float = 0.0, device=None, dtype=None):
super(GIN, self).__init__(num_nodes, hidden_channels, num_layers,
device, dtype)
self.in_channels = in_channels
self.out_channels = out_channels
self.residual = residual
self.dropout = dropout
self.lins = ModuleList()
self.lins.append(Linear(in_channels, hidden_channels))
self.lins.append(Linear(hidden_channels, out_channels))
self.convs = ModuleList()
for _ in range(num_layers):
conv = GINConv(nn=Identity(), train_eps=True)
self.convs.append(conv)
self.post_nns = ModuleList()
for i in range(num_layers):
post_nn = Sequential(
Linear(hidden_channels, hidden_channels),
BatchNorm1d(hidden_channels, track_running_stats=False),
ReLU(inplace=True),
Linear(hidden_channels, hidden_channels),
ReLU(inplace=True),
)
self.post_nns.append(post_nn)
def reset_parameters(self):
super(GIN, self).reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for post_nn in self.post_nns:
reset(post_nn)
for lin in self.lins:
lin.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None) -> Tensor:
x = self.lins[0](x).relu()
for conv, post_nn, history in zip(self.convs[:-1], self.post_nns[:-1],
self.histories):
if batch_size is not None:
h = torch.zeros_like(x)
h[:batch_size] = post_nn(conv(x, adj_t)[:batch_size])
else:
h = post_nn(conv(x, adj_t))
x = h.add_(x) if self.residual else h
x = self.push_and_pull(history, x, batch_size, n_id)
x = F.dropout(x, p=self.dropout, training=self.training)
if batch_size is not None:
h = self.post_nns[-1](self.convs[-1](x, adj_t)[:batch_size])
x = x[:batch_size]
else:
h = self.post_nns[-1](self.convs[-1](x, adj_t))
x = h.add_(x) if self.residual else h
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[1](x)
return x
from itertools import product
from typing import Optional, List
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Linear, BatchNorm1d
from torch_sparse import SparseTensor
from torch_geometric.nn import MessagePassing
from scaling_gnns.models.base2 import ScalableGNN
EPS = 1e-5
class PNAConv(MessagePassing):
def __init__(self, in_channels: int, out_channels: int,
aggregators: List[str], scalers: List[str], deg: Tensor,
**kwargs):
super(PNAConv, self).__init__(aggr=None, **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.aggregators = aggregators
self.scalers = scalers
deg = deg.to(torch.float)
self.avg_deg = {
'lin': deg.mean().item(),
'log': (deg + 1).log().mean().item(),
}
self.pre_lins = torch.nn.ModuleList([
Linear(in_channels, out_channels)
for _ in range(len(aggregators) * len(scalers))
])
self.post_lins = torch.nn.ModuleList([
Linear(out_channels, out_channels)
for _ in range(len(aggregators) * len(scalers))
])
self.lin = Linear(in_channels, out_channels)
self.reset_parameters()
def reset_parameters(self):
for lin in self.pre_lins:
lin.reset_parameters()
for lin in self.post_lins:
lin.reset_parameters()
self.lin.reset_parameters()
def forward(self, x: Tensor, adj_t):
out = self.propagate(adj_t, x=x)
out += self.lin(x)[:out.size(0)]
return out
def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
deg = adj_t.storage.rowcount().to(x.dtype).view(-1, 1)
out = 0
for (aggr, scaler), pre_lin, post_lin in zip(
product(self.aggregators, self.scalers), self.pre_lins,
self.post_lins):
h = pre_lin(x).relu_()
h = adj_t.matmul(h, reduce=aggr)
h = post_lin(h)
if scaler == 'amplification':
h *= (deg + 1).log() / self.avg_deg['log']
elif scaler == 'attenuation':
h *= self.avg_deg['log'] / ((deg + 1).log() + EPS)
out += h
return out
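# Hedged note on the scalers above (illustrative only): with
# avg_deg['log'] = mean(log(d_i + 1)) taken over the degree histogram `deg`,
# a node whose row in adj_t has d non-zero entries has its aggregated
# message rescaled by
#   'amplification': log(d + 1) / avg_deg['log']
#   'attenuation':   avg_deg['log'] / (log(d + 1) + EPS)
# e.g. for avg_deg['log'] = 1.0 and d = 7, the factors are log(8) ~ 2.08 and
# ~ 0.48, boosting or damping high-degree nodes respectively.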
class PNA(ScalableGNN):
def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
out_channels: int, num_layers: int, aggregators: List[str],
scalers: List[str], deg: Tensor, dropout: float = 0.0,
drop_input: bool = True, batch_norm: bool = False,
residual: bool = False, pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(PNA, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
self.dropout = dropout
self.drop_input = drop_input
self.batch_norm = batch_norm
self.residual = residual
self.convs = ModuleList()
for i in range(num_layers):
in_dim = in_channels if i == 0 else hidden_channels
out_dim = out_channels if i == num_layers - 1 else hidden_channels
conv = PNAConv(in_dim, out_dim, aggregators=aggregators,
scalers=scalers, deg=deg)
self.convs.append(conv)
self.bns = ModuleList()
for i in range(num_layers - 1):
bn = BatchNorm1d(hidden_channels)
self.bns.append(bn)
@property
def reg_modules(self):
return ModuleList(list(self.convs[:-1]) + list(self.bns))
@property
def nonreg_modules(self):
return self.convs[-1:]
def reset_parameters(self):
super(PNA, self).reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
for conv, bn, hist in zip(self.convs[:-1], self.bns, self.histories):
h = conv(x, adj_t)
if self.batch_norm:
h = bn(h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
x = h.relu_()
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, adj_t)
return x
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
if layer == 0 and self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[layer](x, adj_t)
if layer < self.num_layers - 1:
if self.batch_norm:
h = self.bns[layer](h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
h = h.relu_()
h = F.dropout(h, p=self.dropout, training=self.training)
return h
from typing import Optional, List
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import (ModuleList, Linear, BatchNorm1d, Sequential, ReLU,
Identity)
from torch_sparse import SparseTensor
from scaling_gnns.models.base2 import ScalableGNN
from scaling_gnns.models.pna import PNAConv
class PNA_JK(ScalableGNN):
def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
out_channels: int, num_layers: int, aggregators: List[str],
scalers: List[str], deg: Tensor, dropout: float = 0.0,
drop_input: bool = True, batch_norm: bool = False,
residual: bool = False, pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(PNA_JK, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_layers = num_layers
self.dropout = dropout
self.drop_input = drop_input
self.batch_norm = batch_norm
self.residual = residual
self.lins = ModuleList()
self.lins.append(
Sequential(
Linear(in_channels, hidden_channels),
BatchNorm1d(hidden_channels) if batch_norm else Identity(),
ReLU(inplace=True),
))
self.lins.append(
Linear((num_layers + 1) * hidden_channels, out_channels))
self.convs = ModuleList()
for _ in range(num_layers):
conv = PNAConv(hidden_channels, hidden_channels,
aggregators=aggregators, scalers=scalers, deg=deg)
self.convs.append(conv)
self.bns = ModuleList()
for _ in range(num_layers):
bn = BatchNorm1d(hidden_channels)
self.bns.append(bn)
@property
def reg_modules(self):
return ModuleList(list(self.convs) + list(self.bns))
@property
def nonreg_modules(self):
return self.lins
def reset_parameters(self):
super(PNA_JK, self).reset_parameters()
for lin in self.lins:
lin.reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[0](x)
xs = [x[:adj_t.size(0)]]
for conv, bn, hist in zip(self.convs[:-1], self.bns[:-1],
self.histories):
h = conv(x, adj_t)
if self.batch_norm:
h = bn(h)
if self.residual:
h += x[:h.size(0)]
x = h.relu_()
xs += [x]
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[-1](x, adj_t)
if self.batch_norm:
h = self.bns[-1](h)
if self.residual:
h += x[:h.size(0)]
x = h.relu_()
xs += [x]
x = torch.cat(xs, dim=-1)
x = F.dropout(x, p=self.dropout, training=self.training)
return self.lins[1](x)
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
if layer == 0:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[0](x)
state['xs'] = [x[:adj_t.size(0)]]
h = self.convs[layer](x, adj_t)
if self.batch_norm:
h = self.bns[layer](h)
if self.residual:
h += x[:h.size(0)]
h = h.relu_()
state['xs'] += [h]
h = F.dropout(h, p=self.dropout, training=self.training)
if layer == self.num_layers - 1:
h = torch.cat(state['xs'], dim=-1)
h = F.dropout(h, p=self.dropout, training=self.training)
h = self.lins[1](h)
return h
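# Hedged note (illustrative only): the final Linear takes
# (num_layers + 1) * hidden_channels inputs because `xs` collects the initial
# projection lins[0](x) plus one hidden_channels-wide output per conv layer,
# and these are concatenated along the feature dimension before lins[1].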
from typing import Optional, Callable
import torch
from torch import Tensor
from torch.cuda import Stream
synchronize = torch.ops.scaling_gnns.synchronize
read_async = torch.ops.scaling_gnns.read_async
write_async = torch.ops.scaling_gnns.write_async
class AsyncIOPool(torch.nn.Module):
def __init__(self, pool_size: int, buffer_size: int, embedding_dim: int):
super(AsyncIOPool, self).__init__()
self.pool_size = pool_size
self.embedding_dim = embedding_dim
self.buffer_size = buffer_size
self._device = torch.device('cpu')
self._pull_queue = []
self._push_cache = [None] * pool_size
self._push_streams = [None] * pool_size
self._pull_streams = [None] * pool_size
self._cpu_buffers = [None] * pool_size
self._cuda_buffers = [None] * pool_size
self._pull_index = -1
self._push_index = -1
def _apply(self, fn: Callable) -> 'AsyncIOPool':
self._device = fn(torch.zeros(1)).device
return self
def _pull_stream(self, idx: int) -> Stream:
if self._pull_streams[idx] is None:
assert str(self._device)[:4] == 'cuda'
self._pull_streams[idx] = torch.cuda.Stream(self._device)
return self._pull_streams[idx]
def _push_stream(self, idx: int) -> Stream:
if self._push_streams[idx] is None:
assert str(self._device)[:4] == 'cuda'
self._push_streams[idx] = torch.cuda.Stream(self._device)
return self._push_streams[idx]
def _cpu_buffer(self, idx: int) -> Tensor:
if self._cpu_buffers[idx] is None:
self._cpu_buffers[idx] = torch.empty(self.buffer_size,
self.embedding_dim,
pin_memory=True)
return self._cpu_buffers[idx]
def _cuda_buffer(self, idx: int) -> Tensor:
if self._cuda_buffers[idx] is None:
assert str(self._device)[:4] == 'cuda'
self._cuda_buffers[idx] = torch.empty(self.buffer_size,
self.embedding_dim,
device=self._device)
return self._cuda_buffers[idx]
@torch.no_grad()
def async_pull(self, src: Tensor, offset: Optional[Tensor],
count: Optional[Tensor], index: Tensor) -> None:
self._pull_index = (self._pull_index + 1) % self.pool_size
data = (self._pull_index, src, offset, count, index)
self._pull_queue.append(data)
if len(self._pull_queue) <= self.pool_size:
self._async_pull(self._pull_index, src, offset, count, index)
@torch.no_grad()
def _async_pull(self, idx: int, src: Tensor, offset: Optional[Tensor],
count: Optional[Tensor], index: Tensor) -> None:
with torch.cuda.stream(self._pull_stream(idx)):
read_async(src, offset, count, index, self._cuda_buffer(idx),
self._cpu_buffer(idx))
@torch.no_grad()
def synchronize_pull(self) -> Tensor:
idx = self._pull_queue[0][0]
synchronize()
torch.cuda.synchronize(self._pull_stream(idx))
return self._cuda_buffer(idx)
@torch.no_grad()
def free_pull(self) -> None:
self._pull_queue.pop(0)
if len(self._pull_queue) >= self.pool_size:
data = self._pull_queue[self.pool_size - 1]
idx, src, offset, count, index = data
self._async_pull(idx, src, offset, count, index)
if len(self._pull_queue) == 0:
self._pull_index = -1
@torch.no_grad()
def async_push(self, src: Tensor, offset: Tensor, count: Tensor,
dst: Tensor) -> None:
self._push_index = (self._push_index + 1) % self.pool_size
self.synchronize_push(self._push_index)
self._push_cache[self._push_index] = src
with torch.cuda.stream(self._push_stream(self._push_index)):
write_async(src, offset, count, dst)
@torch.no_grad()
def synchronize_push(self, idx: Optional[int] = None) -> None:
if idx is None:
for idx in range(self.pool_size):
self.synchronize_push(idx)
self._push_index = -1
else:
torch.cuda.synchronize(self._push_stream(idx))
self._push_cache[idx] = None
def forward(self, *args, **kwargs):
""""""
raise NotImplementedError
def __repr__(self):
return (f'{self.__class__.__name__}(pool_size={self.pool_size}, '
f'buffer_size={self.buffer_size}, '
f'embedding_dim={self.embedding_dim}, '
f'device={self._device})')
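# --- Hedged usage sketch (illustrative only, not part of this commit). ---
# The pool pipelines host<->device transfers: up to `pool_size` pulls run
# ahead on dedicated CUDA streams, and pulls are consumed strictly in FIFO
# order. Assuming the compiled scaling_gnns extension (read_async /
# write_async) is available and `emb` is a pinned CPU matrix of shape
# [num_nodes, dim]:
#
#   pool = AsyncIOPool(pool_size=2, buffer_size=50000,
#                      embedding_dim=dim).cuda()
#   for offset, count, n_id in chunks:        # schedule reads ahead of use
#       pool.async_pull(emb, offset, count, n_id)
#   for _ in chunks:
#       cuda_chunk = pool.synchronize_pull()  # wait for the oldest pull
#       ...                                   # consume cuda_chunk
#       pool.free_pull()                      # free its buffer/stream slot
#   pool.synchronize_push()                   # drain any outstanding writes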
from typing import Optional
import torch
from torch import Tensor
def index2mask(idx: Tensor, size: int) -> Tensor:
mask = torch.zeros(size, dtype=torch.bool, device=idx.device)
mask[idx] = True
return mask
def compute_acc(logits: Tensor, y: Tensor, mask: Optional[Tensor] = None):
if mask is not None:
logits, y = logits[mask], y[mask]
if y.dim() == 1:
return int(logits.argmax(dim=-1).eq(y).sum()) / y.size(0)
else:
y_pred = logits > 0
y_true = y > 0.5
tp = int((y_true & y_pred).sum())
fp = int((~y_true & y_pred).sum())
fn = int((y_true & ~y_pred).sum())
precision = tp / (tp + fp)
recall = tp / (tp + fn)
return 2 * (precision * recall) / (precision + recall)
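# Hedged worked example for the multi-label branch above (illustrative only):
# logits = [[2., -1.], [-3., 4.]] and y = [[1., 1.], [0., 1.]] give
# y_pred = [[True, False], [False, True]] and
# y_true = [[True, True], [False, True]], hence tp = 2, fp = 0, fn = 1,
# precision = 1.0, recall = 2/3 and micro-F1 = 2 * (1.0 * 2/3) / (1.0 + 2/3) = 0.8.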
def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
num_splits: int = 20):
num_classes = int(y.max()) + 1
train_mask = torch.zeros(y.size(0), num_splits, dtype=torch.bool)
val_mask = torch.zeros(y.size(0), num_splits, dtype=torch.bool)
for c in range(num_classes):
idx = (y == c).nonzero(as_tuple=False).view(-1)
perm = torch.stack(
[torch.randperm(idx.size(0)) for _ in range(num_splits)], dim=1)
idx = idx[perm]
train_idx = idx[:train_per_class]
train_mask.scatter_(0, train_idx, True)
val_idx = idx[train_per_class:train_per_class + val_per_class]
val_mask.scatter_(0, val_idx, True)
test_mask = ~(train_mask | val_mask)
return train_mask, val_mask, test_mask
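# Hedged usage sketch (illustrative only): for a label vector y of shape
# [num_nodes], gen_masks returns boolean masks of shape
# [num_nodes, num_splits], each split holding `train_per_class` /
# `val_per_class` nodes per class and marking the remainder as test:
#
#   y = torch.randint(0, 7, (2708, ))
#   train_mask, val_mask, test_mask = gen_masks(y, 20, 30, num_splits=20)
#   assert train_mask.shape == (2708, 20)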