"vscode:/vscode.git/clone" did not exist on "1ab6be1b2666eb77cc4f849e8bf7dfb7e1856f48"
Commit 2f25da6c authored by rusty1s

initial commit

parent ac165af3
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, BatchNorm1d
from torch_sparse import SparseTensor
from torch_geometric.nn import GCNConv
from scaling_gnns.models.base2 import ScalableGNN
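# GCN backbone on top of ScalableGNN: each hidden layer is paired with a
# "history" embedding table (self.histories, provided by the base class in
# scaling_gnns.models.base2, which is not part of this excerpt) that
# push_and_pull() is expected to keep in sync, so mini-batch steps only
# recompute embeddings for in-batch nodes. GCNConv is created with
# normalize=False, i.e. adj_t is expected to be GCN-normalized in advance.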
class GCN(ScalableGNN):
def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
out_channels: int, num_layers: int, dropout: float = 0.0,
drop_input: bool = True, batch_norm: bool = False,
residual: bool = False, pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(GCN, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
self.dropout = dropout
self.drop_input = drop_input
self.batch_norm = batch_norm
self.residual = residual
self.convs = ModuleList()
for i in range(num_layers):
in_dim = in_channels if i == 0 else hidden_channels
out_dim = out_channels if i == num_layers - 1 else hidden_channels
conv = GCNConv(in_dim, out_dim, normalize=False)
self.convs.append(conv)
self.bns = ModuleList()
for i in range(num_layers - 1):
bn = BatchNorm1d(hidden_channels)
self.bns.append(bn)
@property
def reg_modules(self):
return ModuleList(list(self.convs[:-1]) + list(self.bns))
@property
def nonreg_modules(self):
return self.convs[-1:]
def reset_parameters(self):
super(GCN, self).reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
for conv, bn, hist in zip(self.convs[:-1], self.bns, self.histories):
h = conv(x, adj_t)
if self.batch_norm:
h = bn(h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
x = h.relu_()
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, adj_t)
return x
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
        if layer == 0:
            if self.drop_input:
                x = F.dropout(x, p=self.dropout, training=self.training)
        else:
            x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[layer](x, adj_t)
if layer < self.num_layers - 1:
if self.batch_norm:
h = self.bns[layer](h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
h = h.relu_()
return h
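# Hedged usage sketch (not part of the commit): a plain full-batch forward
# pass with GCN, using the module-level imports above. It assumes that
# ScalableGNN supports batch_size=None for full-batch execution and that
# adj_t is already GCN-normalized, since GCNConv is built with normalize=False.
if __name__ == "__main__":
    num_nodes, num_feats, num_classes = 1000, 128, 40
    edge_index = torch.randint(num_nodes, (2, 5000))
    adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                         sparse_sizes=(num_nodes, num_nodes))
    x = torch.randn(num_nodes, num_feats)

    model = GCN(num_nodes=num_nodes, in_channels=num_feats,
                hidden_channels=256, out_channels=num_classes, num_layers=3)
    out = model(x, adj_t)
    print(out.shape)  # -> [num_nodes, num_classes]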
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Linear, BatchNorm1d
from torch_sparse import SparseTensor
from torch_geometric.nn import GCN2Conv
from scaling_gnns.models.base2 import ScalableGNN
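# GCNII ("GCN2") version of the same history-based scheme: an input/output
# Linear pair around GCN2Conv layers, each of which reuses the initial
# projection x_0 (initial residual / identity mapping). As above, adj_t has
# to be pre-normalized because the convolutions use normalize=False.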
class GCN2(ScalableGNN):
def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
out_channels: int, num_layers: int, alpha: float,
theta: float, shared_weights: bool = True,
dropout: float = 0.0, drop_input: bool = True,
batch_norm: bool = False, residual: bool = False,
pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(GCN2, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
self.dropout = dropout
self.drop_input = drop_input
self.batch_norm = batch_norm
self.residual = residual
self.lins = ModuleList()
self.lins.append(Linear(in_channels, hidden_channels))
self.lins.append(Linear(hidden_channels, out_channels))
self.convs = ModuleList()
for i in range(num_layers):
conv = GCN2Conv(hidden_channels, alpha=alpha, theta=theta,
layer=i + 1, shared_weights=shared_weights,
normalize=False)
self.convs.append(conv)
self.bns = ModuleList()
for i in range(num_layers):
bn = BatchNorm1d(hidden_channels)
self.bns.append(bn)
@property
def reg_modules(self):
return ModuleList(list(self.convs) + list(self.bns))
@property
def nonreg_modules(self):
return self.lins
def reset_parameters(self):
super(GCN2, self).reset_parameters()
for lin in self.lins:
lin.reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
x = x_0 = self.lins[0](x).relu_()
x = F.dropout(x, p=self.dropout, training=self.training)
for conv, bn, hist in zip(self.convs[:-1], self.bns[:-1],
self.histories):
h = conv(x, x_0, adj_t)
if self.batch_norm:
h = bn(h)
if self.residual:
h += x[:h.size(0)]
x = h.relu_()
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[-1](x, x_0, adj_t)
if self.batch_norm:
h = self.bns[-1](h)
if self.residual:
h += x[:h.size(0)]
x = h.relu_()
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[1](x)
return x
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
if layer == 0:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
x = x_0 = self.lins[0](x).relu_()
state['x_0'] = x_0[:adj_t.size(0)]
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[layer](x, state['x_0'], adj_t)
if self.batch_norm:
h = self.bns[layer](h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
x = h.relu_()
if layer == self.num_layers - 1:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[1](x)
return x
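# Hedged sketch (not part of the commit): full-graph, layer-by-layer inference
# by chaining forward_layer(); the `state` dict carries the initial projection
# x_0 that every GCN2Conv reuses. Assumes the ScalableGNN defaults used above
# and a pre-normalized adj_t; uses the module-level imports of this file.
if __name__ == "__main__":
    num_nodes, num_feats, num_classes = 1000, 128, 40
    edge_index = torch.randint(num_nodes, (2, 5000))
    adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                         sparse_sizes=(num_nodes, num_nodes))
    x = torch.randn(num_nodes, num_feats)

    model = GCN2(num_nodes, num_feats, 64, num_classes, num_layers=4,
                 alpha=0.1, theta=0.5).eval()
    state = {}
    out = x
    for layer in range(model.num_layers):
        out = model.forward_layer(layer, out, adj_t, state)
    print(out.shape)  # -> [num_nodes, num_classes]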
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Identity
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_sparse import SparseTensor
from torch_geometric.nn import GINConv
from torch_geometric.nn.inits import reset
from .base import HistoryGNN
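# GIN variant built on the older HistoryGNN base class (.base, not shown):
# each GINConv uses an Identity() update function and is followed by a
# separate post-aggregation MLP (post_nns). Note that push_and_pull() here
# takes no offset/count arguments, unlike the ScalableGNN models above.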
class GIN(HistoryGNN):
def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
out_channels: int, num_layers: int, residual: bool = False,
dropout: float = 0.0, device=None, dtype=None):
super(GIN, self).__init__(num_nodes, hidden_channels, num_layers,
device, dtype)
self.in_channels = in_channels
self.out_channels = out_channels
self.residual = residual
self.dropout = dropout
self.lins = ModuleList()
self.lins.append(Linear(in_channels, hidden_channels))
self.lins.append(Linear(hidden_channels, out_channels))
self.convs = ModuleList()
for _ in range(num_layers):
conv = GINConv(nn=Identity(), train_eps=True)
self.convs.append(conv)
self.post_nns = ModuleList()
for i in range(num_layers):
post_nn = Sequential(
Linear(hidden_channels, hidden_channels),
BatchNorm1d(hidden_channels, track_running_stats=False),
ReLU(inplace=True),
Linear(hidden_channels, hidden_channels),
ReLU(inplace=True),
)
self.post_nns.append(post_nn)
def reset_parameters(self):
super(GIN, self).reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for post_nn in self.post_nns:
reset(post_nn)
for lin in self.lins:
lin.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None) -> Tensor:
x = self.lins[0](x).relu()
for conv, post_nn, history in zip(self.convs[:-1], self.post_nns[:-1],
self.histories):
if batch_size is not None:
h = torch.zeros_like(x)
h[:batch_size] = post_nn(conv(x, adj_t)[:batch_size])
else:
h = post_nn(conv(x, adj_t))
x = h.add_(x) if self.residual else h
x = self.push_and_pull(history, x, batch_size, n_id)
x = F.dropout(x, p=self.dropout, training=self.training)
if batch_size is not None:
h = self.post_nns[-1](self.convs[-1](x, adj_t)[:batch_size])
x = x[:batch_size]
else:
h = self.post_nns[-1](self.convs[-1](x, adj_t))
x = h.add_(x) if self.residual else h
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[1](x)
return x
from itertools import product
from typing import Optional, List
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Linear, BatchNorm1d
from torch_sparse import SparseTensor
from torch_geometric.nn import MessagePassing
from scaling_gnns.models.base2 import ScalableGNN
EPS = 1e-5
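# Simplified Principal Neighbourhood Aggregation convolution: one Linear
# before and one after every (aggregator, scaler) combination, plus a root
# skip connection via self.lin. The 'amplification'/'attenuation' scalers
# rescale messages by log-degree statistics of `deg`; EPS guards the division.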
class PNAConv(MessagePassing):
def __init__(self, in_channels: int, out_channels: int,
aggregators: List[str], scalers: List[str], deg: Tensor,
**kwargs):
super(PNAConv, self).__init__(aggr=None, **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.aggregators = aggregators
self.scalers = scalers
deg = deg.to(torch.float)
self.avg_deg = {
'lin': deg.mean().item(),
'log': (deg + 1).log().mean().item(),
}
self.pre_lins = torch.nn.ModuleList([
Linear(in_channels, out_channels)
for _ in range(len(aggregators) * len(scalers))
])
self.post_lins = torch.nn.ModuleList([
Linear(out_channels, out_channels)
for _ in range(len(aggregators) * len(scalers))
])
self.lin = Linear(in_channels, out_channels)
self.reset_parameters()
def reset_parameters(self):
for lin in self.pre_lins:
lin.reset_parameters()
for lin in self.post_lins:
lin.reset_parameters()
self.lin.reset_parameters()
def forward(self, x: Tensor, adj_t):
out = self.propagate(adj_t, x=x)
out += self.lin(x)[:out.size(0)]
return out
def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
deg = adj_t.storage.rowcount().to(x.dtype).view(-1, 1)
out = 0
for (aggr, scaler), pre_lin, post_lin in zip(
product(self.aggregators, self.scalers), self.pre_lins,
self.post_lins):
h = pre_lin(x).relu_()
h = adj_t.matmul(h, reduce=aggr)
h = post_lin(h)
if scaler == 'amplification':
h *= (deg + 1).log() / self.avg_deg['log']
elif scaler == 'attenuation':
h *= self.avg_deg['log'] / ((deg + 1).log() + EPS)
out += h
return out
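# PNAConv expects `deg` to be a per-node degree tensor (only deg.mean() and
# (deg + 1).log().mean() are used above). A minimal, hedged way to build it
# from the same SparseTensor adjacency that is later passed to forward():
#
#     deg = adj_t.storage.rowcount().to(torch.float)
#
# The PNA model below stacks PNAConv layers with the same history-based
# push_and_pull() scheme as GCN above.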
class PNA(ScalableGNN):
def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
out_channels: int, num_layers: int, aggregators: List[int],
scalers: List[int], deg: Tensor, dropout: float = 0.0,
drop_input: bool = True, batch_norm: bool = False,
residual: bool = False, pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(PNA, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
self.dropout = dropout
self.drop_input = drop_input
self.batch_norm = batch_norm
self.residual = residual
self.convs = ModuleList()
for i in range(num_layers):
in_dim = in_channels if i == 0 else hidden_channels
out_dim = out_channels if i == num_layers - 1 else hidden_channels
conv = PNAConv(in_dim, out_dim, aggregators=aggregators,
scalers=scalers, deg=deg)
self.convs.append(conv)
self.bns = ModuleList()
for i in range(num_layers - 1):
bn = BatchNorm1d(hidden_channels)
self.bns.append(bn)
@property
def reg_modules(self):
return ModuleList(list(self.convs[:-1]) + list(self.bns))
@property
def nonreg_modules(self):
return self.convs[-1:]
def reset_parameters(self):
super(PNA, self).reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
for conv, bn, hist in zip(self.convs[:-1], self.bns, self.histories):
h = conv(x, adj_t)
if self.batch_norm:
h = bn(h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
x = h.relu_()
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, adj_t)
return x
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
if layer == 0 and self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[layer](x, adj_t)
if layer < self.num_layers - 1:
if self.batch_norm:
h = self.bns[layer](h)
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
h = h.relu_()
h = F.dropout(h, p=self.dropout, training=self.training)
return h
from typing import Optional, List
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import (ModuleList, Linear, BatchNorm1d, Sequential, ReLU,
Identity)
from torch_sparse import SparseTensor
from scaling_gnns.models.base2 import ScalableGNN
from scaling_gnns.models.pna import PNAConv
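# PNA with jumping-knowledge-style concatenation: the outputs of the input MLP
# and of every PNAConv layer are collected in `xs` and concatenated before the
# final Linear classifier.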
class PNA_JK(ScalableGNN):
def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
out_channels: int, num_layers: int, aggregators: List[int],
scalers: List[int], deg: Tensor, dropout: float = 0.0,
drop_input: bool = True, batch_norm: bool = False,
residual: bool = False, pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(PNA_JK, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
        self.num_layers = num_layers
self.dropout = dropout
self.drop_input = drop_input
self.batch_norm = batch_norm
self.residual = residual
self.lins = ModuleList()
self.lins.append(
Sequential(
Linear(in_channels, hidden_channels),
BatchNorm1d(hidden_channels) if batch_norm else Identity(),
ReLU(inplace=True),
))
self.lins.append(
Linear((num_layers + 1) * hidden_channels, out_channels))
self.convs = ModuleList()
for _ in range(num_layers):
conv = PNAConv(hidden_channels, hidden_channels,
aggregators=aggregators, scalers=scalers, deg=deg)
self.convs.append(conv)
self.bns = ModuleList()
for _ in range(num_layers):
bn = BatchNorm1d(hidden_channels)
self.bns.append(bn)
@property
def reg_modules(self):
return ModuleList(list(self.convs) + list(self.bns))
@property
def nonreg_modules(self):
return self.lins
def reset_parameters(self):
super(PNA_JK, self).reset_parameters()
for lin in self.lins:
lin.reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[0](x)
xs = [x[:adj_t.size(0)]]
for conv, bn, hist in zip(self.convs[:-1], self.bns[:-1],
self.histories):
h = conv(x, adj_t)
if self.batch_norm:
h = bn(h)
if self.residual:
h += x[:h.size(0)]
x = h.relu_()
xs += [x]
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[-1](x, adj_t)
if self.batch_norm:
h = self.bns[-1](h)
if self.residual:
h += x[:h.size(0)]
x = h.relu_()
xs += [x]
x = torch.cat(xs, dim=-1)
x = F.dropout(x, p=self.dropout, training=self.training)
return self.lins[1](x)
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
if layer == 0:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[0](x)
state['xs'] = [x[:adj_t.size(0)]]
h = self.convs[layer](x, adj_t)
if self.batch_norm:
h = self.bns[layer](h)
if self.residual:
h += x[:h.size(0)]
h = h.relu_()
state['xs'] += [h]
h = F.dropout(h, p=self.dropout, training=self.training)
if layer == self.num_layers - 1:
h = torch.cat(state['xs'], dim=-1)
h = F.dropout(h, p=self.dropout, training=self.training)
h = self.lins[1](h)
return h
from typing import Optional, Callable
import torch
from torch import Tensor
from torch.cuda import Stream
synchronize = torch.ops.scaling_gnns.synchronize
read_async = torch.ops.scaling_gnns.read_async
write_async = torch.ops.scaling_gnns.write_async
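# Round-robin pool of pinned CPU buffers, CUDA buffers and dedicated CUDA
# streams that overlaps host<->device transfers of history embeddings with
# computation. It relies on the custom read_async/write_async/synchronize ops
# registered by the compiled scaling_gnns extension (not part of this file).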
class AsyncIOPool(torch.nn.Module):
def __init__(self, pool_size: int, buffer_size: int, embedding_dim: int):
super(AsyncIOPool, self).__init__()
self.pool_size = pool_size
self.embedding_dim = embedding_dim
self.buffer_size = buffer_size
self._device = torch.device('cpu')
self._pull_queue = []
self._push_cache = [None] * pool_size
self._push_streams = [None] * pool_size
self._pull_streams = [None] * pool_size
self._cpu_buffers = [None] * pool_size
self._cuda_buffers = [None] * pool_size
self._pull_index = -1
self._push_index = -1
    def _apply(self, fn: Callable) -> 'AsyncIOPool':
self._device = fn(torch.zeros(1)).device
return self
def _pull_stream(self, idx: int) -> Stream:
if self._pull_streams[idx] is None:
assert str(self._device)[:4] == 'cuda'
self._pull_streams[idx] = torch.cuda.Stream(self._device)
return self._pull_streams[idx]
def _push_stream(self, idx: int) -> Stream:
if self._push_streams[idx] is None:
assert str(self._device)[:4] == 'cuda'
self._push_streams[idx] = torch.cuda.Stream(self._device)
return self._push_streams[idx]
def _cpu_buffer(self, idx: int) -> Tensor:
if self._cpu_buffers[idx] is None:
self._cpu_buffers[idx] = torch.empty(self.buffer_size,
self.embedding_dim,
pin_memory=True)
return self._cpu_buffers[idx]
def _cuda_buffer(self, idx: int) -> Tensor:
if self._cuda_buffers[idx] is None:
assert str(self._device)[:4] == 'cuda'
self._cuda_buffers[idx] = torch.empty(self.buffer_size,
self.embedding_dim,
device=self._device)
return self._cuda_buffers[idx]
@torch.no_grad()
def async_pull(self, src: Tensor, offset: Optional[Tensor],
count: Optional[Tensor], index: Tensor) -> None:
self._pull_index = (self._pull_index + 1) % self.pool_size
data = (self._pull_index, src, offset, count, index)
self._pull_queue.append(data)
if len(self._pull_queue) <= self.pool_size:
self._async_pull(self._pull_index, src, offset, count, index)
@torch.no_grad()
def _async_pull(self, idx: int, src: Tensor, offset: Optional[Tensor],
count: Optional[Tensor], index: Tensor) -> None:
with torch.cuda.stream(self._pull_stream(idx)):
read_async(src, offset, count, index, self._cuda_buffer(idx),
self._cpu_buffer(idx))
@torch.no_grad()
def synchronize_pull(self) -> Tensor:
idx = self._pull_queue[0][0]
synchronize()
torch.cuda.synchronize(self._pull_stream(idx))
return self._cuda_buffer(idx)
@torch.no_grad()
def free_pull(self) -> None:
self._pull_queue.pop(0)
if len(self._pull_queue) >= self.pool_size:
data = self._pull_queue[self.pool_size - 1]
idx, src, offset, count, index = data
self._async_pull(idx, src, offset, count, index)
if len(self._pull_queue) == 0:
self._pull_index = -1
@torch.no_grad()
def async_push(self, src: Tensor, offset: Tensor, count: Tensor,
dst: Tensor) -> None:
self._push_index = (self._push_index + 1) % self.pool_size
self.synchronize_push(self._push_index)
self._push_cache[self._push_index] = src
with torch.cuda.stream(self._push_stream(self._push_index)):
write_async(src, offset, count, dst)
@torch.no_grad()
def synchronize_push(self, idx: Optional[int] = None) -> None:
if idx is None:
for idx in range(self.pool_size):
self.synchronize_push(idx)
self._push_index = -1
else:
torch.cuda.synchronize(self._push_stream(idx))
self._push_cache[idx] = None
def forward(self, *args, **kwargs):
""""""
raise NotImplementedError
def __repr__(self):
return (f'{self.__class__.__name__}(pool_size={self.pool_size}, '
f'buffer_size={self.buffer_size}, '
f'embedding_dim={self.embedding_dim}, '
f'device={self._device})')
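# Hedged sketch (not part of the commit) of the call protocol suggested by the
# methods above. It needs a CUDA device plus the compiled scaling_gnns
# extension providing the custom ops, and `batches`, `history_emb`, `new_emb`
# below are hypothetical placeholders, so it is kept as a comment:
#
#   pool = AsyncIOPool(pool_size=2, buffer_size=20000, embedding_dim=256).cuda()
#
#   for offset, count, n_id in batches:            # queue reads; at most
#       pool.async_pull(history_emb, offset, count, n_id)  # pool_size run eagerly
#
#   for offset, count, n_id in batches:
#       cuda_emb = pool.synchronize_pull()         # wait for the oldest request
#       new_emb = ...                              # hypothetical per-batch compute
#       pool.async_push(new_emb, offset, count, history_emb)
#       pool.free_pull()                           # recycle buffer, start next read
#   pool.synchronize_push()                        # flush all outstanding writes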
from typing import Optional
import torch
from torch import Tensor
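# Small dataset/evaluation helpers: index2mask() builds a boolean node mask
# from an index tensor, compute_acc() returns accuracy for single-label
# targets and micro-F1 for multi-label logits, and gen_masks() draws
# `num_splits` random train/val/test splits with a fixed number of train and
# val nodes per class.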
def index2mask(idx: Tensor, size: int) -> Tensor:
mask = torch.zeros(size, dtype=torch.bool, device=idx.device)
mask[idx] = True
return mask
def compute_acc(logits: Tensor, y: Tensor, mask: Optional[Tensor] = None):
if mask is not None:
logits, y = logits[mask], y[mask]
if y.dim() == 1:
return int(logits.argmax(dim=-1).eq(y).sum()) / y.size(0)
else:
y_pred = logits > 0
y_true = y > 0.5
tp = int((y_true & y_pred).sum())
fp = int((~y_true & y_pred).sum())
fn = int((y_true & ~y_pred).sum())
precision = tp / (tp + fp)
recall = tp / (tp + fn)
return 2 * (precision * recall) / (precision + recall)
def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
num_splits: int = 20):
num_classes = int(y.max()) + 1
train_mask = torch.zeros(y.size(0), num_splits, dtype=torch.bool)
val_mask = torch.zeros(y.size(0), num_splits, dtype=torch.bool)
for c in range(num_classes):
idx = (y == c).nonzero(as_tuple=False).view(-1)
perm = torch.stack(
[torch.randperm(idx.size(0)) for _ in range(num_splits)], dim=1)
idx = idx[perm]
train_idx = idx[:train_per_class]
train_mask.scatter_(0, train_idx, True)
val_idx = idx[train_per_class:train_per_class + val_per_class]
val_mask.scatter_(0, val_idx, True)
test_mask = ~(train_mask | val_mask)
return train_mask, val_mask, test_mask
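# Hedged usage sketch (not part of the commit) for the helpers above.
if __name__ == "__main__":
    y = torch.randint(0, 7, (1000,))                      # single-label targets
    train_mask, val_mask, test_mask = gen_masks(y, train_per_class=20,
                                                val_per_class=30, num_splits=20)
    print(train_mask.shape)                               # -> [1000, 20]

    logits = torch.randn(1000, 7)
    acc = compute_acc(logits, y, mask=train_mask[:, 0])   # accuracy on split 0
    print(acc)

    print(index2mask(torch.tensor([0, 2, 5]), size=8))    # boolean mask, length 8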