Commit 38ca4fb2 authored by rusty1s

update super calls

parent 07932207
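Two patterns recur throughout the diff below: every `super(Class, self)` call is replaced by the zero-argument `super()` form available since Python 3, and the verbose `forward` signatures are collapsed into `*args` that are forwarded unchanged to `push_and_pull`. A minimal sketch of the `super()` change (the class names here are illustrative, not from this repository):

```python
import torch


class Old(torch.nn.Module):
    def __init__(self):
        # Python 2-compatible spelling: repeats the class name, so it breaks
        # if the class is renamed or the method is copied to another class.
        super(Old, self).__init__()


class New(torch.nn.Module):
    def __init__(self):
        # Python 3 zero-argument form: the interpreter supplies the class and
        # instance through the implicit __class__ closure cell.
        super().__init__()
```

Both spellings behave identically at runtime; the zero-argument form is simply the idiomatic Python 3 style.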
@@ -10,7 +10,6 @@ params:
   hidden_channels: 8
   hidden_heads: 8
   out_heads: 1
-  residual: false
   dropout: 0.6
   num_parts: 40
   batch_size: 10
@@ -28,7 +27,6 @@ params:
   hidden_channels: 8
   hidden_heads: 8
   out_heads: 1
-  residual: false
   dropout: 0.6
   num_parts: 24
   batch_size: 8
@@ -46,7 +44,6 @@ params:
   hidden_channels: 8
   hidden_heads: 8
   out_heads: 8
-  residual: false
   dropout: 0.6
   num_parts: 4
   batch_size: 1
@@ -64,7 +61,6 @@ params:
   hidden_channels: 8
   hidden_heads: 8
   out_heads: 1
-  residual: false
   dropout: 0.6
   num_parts: 8
   batch_size: 2
@@ -82,7 +78,6 @@ params:
   hidden_channels: 8
   hidden_heads: 8
   out_heads: 1
-  residual: false
   dropout: 0.6
   num_parts: 4
   batch_size: 1
@@ -100,7 +95,6 @@ params:
   hidden_channels: 14
   hidden_heads: 5
   out_heads: 1
-  residual: false
   dropout: 0.5
   num_parts: 2
   batch_size: 1
...
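The `residual: false` entries are deleted from every params block because `GAT.__init__` (further down in this commit) no longer accepts a `residual` argument. If such params are splatted into the constructor as keyword arguments, a stale key fails immediately; a minimal sketch, with `build_gat` as a hypothetical stand-in for the real constructor:

```python
import yaml


def build_gat(hidden_channels: int, dropout: float = 0.0):
    # Hypothetical stand-in: accepts only the surviving config keys.
    return ('GAT', hidden_channels, dropout)


params = yaml.safe_load("""
hidden_channels: 8
dropout: 0.6
residual: false
""")

try:
    build_gat(**params)
except TypeError as err:
    # build_gat() got an unexpected keyword argument 'residual'
    print(err)
```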
@@ -14,8 +14,8 @@ class APPNP(ScalableGNN):
                  out_channels: int, num_layers: int, alpha: float,
                  dropout: float = 0.0, pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
-        super(APPNP, self).__init__(num_nodes, out_channels, num_layers,
-                                    pool_size, buffer_size, device)
+        super().__init__(num_nodes, out_channels, num_layers, pool_size,
+                         buffer_size, device)

         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -30,15 +30,11 @@ class APPNP(ScalableGNN):
         self.nonreg_modules = self.lins[1:]

     def reset_parameters(self):
-        super(APPNP, self).reset_parameters()
+        super().reset_parameters()
         for lin in self.lins:
             lin.reset_parameters()

-    def forward(self, x: Tensor, adj_t: SparseTensor,
-                batch_size: Optional[int] = None,
-                n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
-                count: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.lins[0](x)
         x = x.relu()
@@ -48,7 +44,7 @@ class APPNP(ScalableGNN):
         for history in self.histories:
             x = (1 - self.alpha) * (adj_t @ x) + self.alpha * x_0
-            x = self.push_and_pull(history, x, batch_size, n_id, offset, count)
+            x = self.push_and_pull(history, x, *args)

         x = (1 - self.alpha) * (adj_t @ x) + self.alpha * x_0

         return x
...
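For context, the loop body `x = (1 - self.alpha) * (adj_t @ x) + self.alpha * x_0` is the APPNP power-iteration step (approximate personalized PageRank), with `adj_t` playing the role of the normalized adjacency and `x_0` the output of the input MLP:

```latex
X^{(k+1)} = (1 - \alpha)\,\hat{A}\,X^{(k)} + \alpha X^{(0)}
```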
@@ -14,7 +14,7 @@ class ScalableGNN(torch.nn.Module):
     def __init__(self, num_nodes: int, hidden_channels: int, num_layers: int,
                  pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
-        super(ScalableGNN, self).__init__()
+        super().__init__()

         self.num_nodes = num_nodes
         self.hidden_channels = hidden_channels
@@ -40,7 +40,7 @@ class ScalableGNN(torch.nn.Module):
         return self.histories[0]._device

     def _apply(self, fn: Callable) -> None:
-        super(ScalableGNN, self)._apply(fn)
+        super()._apply(fn)
         # We only initialize the AsyncIOPool in case histories are on CPU:
         if (str(self.emb_device) == 'cpu' and str(self.device)[:4] == 'cuda'
                 and self.pool_size is not None
...
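`_apply` is the internal hook through which `Module.to()`, `.cuda()`, and `.cpu()` funnel, so overriding it lets `ScalableGNN` notice device moves and set up its `AsyncIOPool` only when the histories live on the CPU while computation runs on CUDA. A minimal sketch of the same override pattern (the `DeviceAware` class is illustrative, not from this repository):

```python
import torch


class DeviceAware(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def _apply(self, fn):
        # Let torch.nn.Module move all parameters and buffers first...
        super()._apply(fn)
        # ...then run custom device-dependent setup, analogous to
        # ScalableGNN conditionally creating its AsyncIOPool here.
        print('parameters now live on', next(self.parameters()).device)
        return self


model = DeviceAware().to('cpu')  # prints: parameters now live on cpu
```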
@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 from torch import Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, ModuleList
+from torch.nn import ModuleList
 from torch_sparse import SparseTensor
 from torch_geometric.nn import GATConv
@@ -13,17 +13,16 @@ from torch_geometric_autoscale.models import ScalableGNN
 class GAT(ScalableGNN):
     def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
                  hidden_heads: int, out_channels: int, out_heads: int,
-                 num_layers: int, residual: bool = False, dropout: float = 0.0,
+                 num_layers: int, dropout: float = 0.0,
                  pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
-        super(GAT, self).__init__(num_nodes, hidden_channels * hidden_heads,
-                                  num_layers, pool_size, buffer_size, device)
+        super().__init__(num_nodes, hidden_channels * hidden_heads, num_layers,
+                         pool_size, buffer_size, device)

         self.in_channels = in_channels
         self.hidden_heads = hidden_heads
         self.out_channels = out_channels
         self.out_heads = out_heads
-        self.residual = residual
         self.dropout = dropout
@@ -37,62 +36,33 @@ class GAT(ScalableGNN):
                        concat=False, dropout=dropout, add_self_loops=False)
         self.convs.append(conv)

-        self.lins = ModuleList()
-        if residual:
-            self.lins.append(
-                Linear(in_channels, hidden_channels * hidden_heads))
-        self.lins.append(
-            Linear(hidden_channels * hidden_heads, out_channels))
-
-        self.reg_modules = ModuleList([self.convs, self.lins])
+        self.reg_modules = self.convs
         self.nonreg_modules = ModuleList()

     def reset_parameters(self):
-        super(GAT, self).reset_parameters()
+        super().reset_parameters()
         for conv in self.convs:
             conv.reset_parameters()
         for lin in self.lins:
             lin.reset_parameters()

-    def forward(self, x: Tensor, adj_t: SparseTensor,
-                batch_size: Optional[int] = None,
-                n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
-                count: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
         for conv, history in zip(self.convs[:-1], self.histories):
-            h = F.dropout(x, p=self.dropout, training=self.training)
-            h = conv((h, h[:adj_t.size(0)]), adj_t)
-            if self.residual:
-                x = F.dropout(x, p=self.dropout, training=self.training)
-                h += x if h.size(-1) == x.size(-1) else self.lins[0](x)
-            x = F.elu(h)
-            x = self.push_and_pull(history, x, batch_size, n_id, offset, count)
-
-        h = F.dropout(x, p=self.dropout, training=self.training)
-        h = self.convs[-1]((h, h[:adj_t.size(0)]), adj_t)
-        if self.residual:
-            x = F.dropout(x, p=self.dropout, training=self.training)
-            h += self.lins[1](x)
-
-        return h
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = conv((x, x[:adj_t.size(0)]), adj_t)
+            x = F.elu(x)
+            x = self.push_and_pull(history, x, *args)
+
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.convs[-1]((x, x[:adj_t.size(0)]), adj_t)
+
+        return x

     @torch.no_grad()
     def forward_layer(self, layer, x, adj_t, state):
-        h = F.dropout(x, p=self.dropout, training=self.training)
-        h = self.convs[layer]((h, h[:adj_t.size(0)]), adj_t)
-        if layer == 0:
-            x = F.dropout(x, p=self.dropout, training=self.training)
-            x = self.lins[0](x)
-        if layer == self.num_layers - 1:
-            x = F.dropout(x, p=self.dropout, training=self.training)
-            x = self.lins[1](x)
-        if self.residual:
-            x = F.dropout(x, p=self.dropout, training=self.training)
-            h += x
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.convs[layer]((x, x[:adj_t.size(0)]), adj_t)
         if layer < self.num_layers - 1:
-            h = h.elu()
-        return h
+            x = x.elu()
+        return x
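The `(x, x[:adj_t.size(0)])` tuples follow PyG's bipartite message-passing convention: in GAS mini-batches the first `adj_t.size(0)` rows of `x` are the destination nodes, so source and destination features are passed separately. A standalone sketch of the same call pattern, with illustrative sizes and a dense `edge_index` in place of the `SparseTensor`:

```python
import torch
from torch_geometric.nn import GATConv

num_src, num_dst = 10, 4            # destination nodes come first in x
x = torch.randn(num_src, 16)
edge_index = torch.tensor([[4, 5, 9, 2],   # source node indices
                           [0, 1, 2, 3]])  # destination indices (< num_dst)

conv = GATConv(16, 8, heads=2, add_self_loops=False)
out = conv((x, x[:num_dst]), edge_index)
print(out.shape)  # torch.Size([4, 16]), i.e. 8 channels * 2 heads
```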
@@ -17,8 +17,8 @@ class GCN(ScalableGNN):
                  residual: bool = False, linear: bool = False,
                  pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
-        super(GCN, self).__init__(num_nodes, hidden_channels, num_layers,
-                                  pool_size, buffer_size, device)
+        super().__init__(num_nodes, hidden_channels, num_layers, pool_size,
+                         buffer_size, device)

         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -60,7 +60,7 @@ class GCN(ScalableGNN):
         return self.lins if self.linear else self.convs[-1:]

     def reset_parameters(self):
-        super(GCN, self).reset_parameters()
+        super().reset_parameters()
         for lin in self.lins:
             lin.reset_parameters()
         for conv in self.convs:
@@ -68,11 +68,7 @@ class GCN(ScalableGNN):
         for bn in self.bns:
             bn.reset_parameters()

-    def forward(self, x: Tensor, adj_t: SparseTensor,
-                batch_size: Optional[int] = None,
-                n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
-                count: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
         if self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)
@@ -87,7 +83,7 @@ class GCN(ScalableGNN):
             if self.residual and h.size(-1) == x.size(-1):
                 h += x[:h.size(0)]
             x = h.relu_()
-            x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
+            x = self.push_and_pull(hist, x, *args)
             x = F.dropout(x, p=self.dropout, training=self.training)

         h = self.convs[-1](x, adj_t)
...
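Collapsing `batch_size`, `n_id`, `offset`, and `count` into `*args` works because `forward` never inspects them; it only forwards them, positionally and in order, to `push_and_pull`. The trade-off is that callers can no longer pass them by keyword. A minimal sketch of that pass-through contract (the `Sketch` class is illustrative, not the repository's API):

```python
from typing import Optional

from torch import Tensor


class Sketch:
    def push_and_pull(self, history, x: Tensor,
                      batch_size: Optional[int] = None,
                      n_id: Optional[Tensor] = None,
                      offset: Optional[Tensor] = None,
                      count: Optional[Tensor] = None) -> Tensor:
        # Stand-in for the real method, which (roughly) pushes the in-batch
        # rows of `x` into the layer's history and pulls cached embeddings
        # for the remaining rows.
        return x

    def forward(self, x: Tensor, adj_t, *args) -> Tensor:
        # `*args` must arrive in push_and_pull's positional order:
        # (batch_size, n_id, offset, count).
        return self.push_and_pull(None, x, *args)
```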
@@ -18,8 +18,8 @@ class GCN2(ScalableGNN):
                  batch_norm: bool = False, residual: bool = False,
                  pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
-        super(GCN2, self).__init__(num_nodes, hidden_channels, num_layers,
-                                   pool_size, buffer_size, device)
+        super().__init__(num_nodes, hidden_channels, num_layers, pool_size,
+                         buffer_size, device)

         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -53,7 +53,7 @@ class GCN2(ScalableGNN):
         return self.lins

     def reset_parameters(self):
-        super(GCN2, self).reset_parameters()
+        super().reset_parameters()
         for lin in self.lins:
             lin.reset_parameters()
         for conv in self.convs:
@@ -61,11 +61,7 @@ class GCN2(ScalableGNN):
         for bn in self.bns:
             bn.reset_parameters()

-    def forward(self, x: Tensor, adj_t: SparseTensor,
-                batch_size: Optional[int] = None,
-                n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
-                count: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
         if self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)
@@ -80,7 +76,7 @@ class GCN2(ScalableGNN):
             if self.residual:
                 h += x[:h.size(0)]
             x = h.relu_()
-            x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
+            x = self.push_and_pull(hist, x, *args)
             x = F.dropout(x, p=self.dropout, training=self.training)

         h = self.convs[-1](x, x_0, adj_t)
...
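For context, `self.convs[-1](x, x_0, adj_t)` differs from the plain-GCN call because `GCN2Conv` implements GCNII, which threads the initial embedding `x_0` through every layer as an initial residual connection. Its update, with teleport strength $\alpha$ and layer-wise identity-mapping weight $\beta_\ell$, is:

```latex
X^{(\ell+1)} = \sigma\Big(\big((1-\alpha)\,\hat{A}\,X^{(\ell)} + \alpha X^{(0)}\big)\big((1-\beta_\ell)\,I + \beta_\ell\,W^{(\ell)}\big)\Big)
```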
@@ -17,8 +17,7 @@ class PNAConv(MessagePassing):
     def __init__(self, in_channels: int, out_channels: int,
                  aggregators: List[str], scalers: List[str], deg: Tensor,
                  **kwargs):
-
-        super(PNAConv, self).__init__(aggr=None, **kwargs)
+        super().__init__(aggr=None, **kwargs)

         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -83,8 +82,8 @@ class PNA(ScalableGNN):
                  drop_input: bool = True, batch_norm: bool = False,
                  residual: bool = False, pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
-        super(PNA, self).__init__(num_nodes, hidden_channels, num_layers,
-                                  pool_size, buffer_size, device)
+        super().__init__(num_nodes, hidden_channels, num_layers, pool_size,
+                         buffer_size, device)

         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -115,17 +114,13 @@ class PNA(ScalableGNN):
         return self.convs[-1:]

     def reset_parameters(self):
-        super(PNA, self).reset_parameters()
+        super().reset_parameters()
         for conv in self.convs:
             conv.reset_parameters()
         for bn in self.bns:
             bn.reset_parameters()

-    def forward(self, x: Tensor, adj_t: SparseTensor,
-                batch_size: Optional[int] = None,
-                n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
-                count: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
         if self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)
@@ -136,7 +131,7 @@ class PNA(ScalableGNN):
             if self.residual and h.size(-1) == x.size(-1):
                 h += x[:h.size(0)]
             x = h.relu_()
-            x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
+            x = self.push_and_pull(hist, x, *args)
             x = F.dropout(x, p=self.dropout, training=self.training)

         x = self.convs[-1](x, adj_t)
...
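`super().__init__(aggr=None, **kwargs)` opts out of `MessagePassing`'s built-in aggregation so the subclass can override `aggregate`, which is where PNA combines multiple aggregators and degree scalers. A minimal sketch of that mechanism with just two aggregators (the `MeanMaxConv` class is illustrative, far simpler than the repository's `PNAConv`):

```python
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter
from torch_geometric.nn import MessagePassing


class MeanMaxConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr=None)  # disable the built-in aggregation
        self.lin = torch.nn.Linear(2 * in_channels, out_channels)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        return self.propagate(edge_index, x=x)

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None) -> Tensor:
        # Run several aggregations and concatenate them, as PNA does with
        # its full set of aggregators and degree scalers:
        mean = scatter(inputs, index, dim=0, dim_size=dim_size, reduce='mean')
        maxi = scatter(inputs, index, dim=0, dim_size=dim_size, reduce='max')
        return torch.cat([mean, maxi], dim=-1)

    def update(self, inputs: Tensor) -> Tensor:
        return self.lin(inputs)
```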
@@ -18,8 +18,8 @@ class PNA_JK(ScalableGNN):
                  drop_input: bool = True, batch_norm: bool = False,
                  residual: bool = False, pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
-        super(PNA_JK, self).__init__(num_nodes, hidden_channels, num_layers,
-                                     pool_size, buffer_size, device)
+        super().__init__(num_nodes, hidden_channels, num_layers, pool_size,
+                         buffer_size, device)

         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -59,7 +59,7 @@ class PNA_JK(ScalableGNN):
         return self.lins

     def reset_parameters(self):
-        super(PNA_JK, self).reset_parameters()
+        super().reset_parameters()
         for lin in self.lins:
             lin.reset_parameters()
         for conv in self.convs:
@@ -67,11 +67,7 @@ class PNA_JK(ScalableGNN):
         for bn in self.bns:
             bn.reset_parameters()

-    def forward(self, x: Tensor, adj_t: SparseTensor,
-                batch_size: Optional[int] = None,
-                n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
-                count: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
         if self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)
@@ -87,7 +83,7 @@ class PNA_JK(ScalableGNN):
                 h += x[:h.size(0)]
             x = h.relu_()
             xs += [x]
-            x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
+            x = self.push_and_pull(hist, x, *args)
             x = F.dropout(x, p=self.dropout, training=self.training)

         h = self.convs[-1](x, adj_t)
@@ -104,6 +100,8 @@ class PNA_JK(ScalableGNN):
     @torch.no_grad()
     def forward_layer(self, layer, x, adj_t, state):
+        # We keep the skip connections in GPU memory for now. If one encounters
+        # GPU memory problems, it is advised to push `state['xs']` to the CPU.
         if layer == 0:
             if self.drop_input:
                 x = F.dropout(x, p=self.dropout, training=self.training)
...
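The new comment documents a deliberate memory/speed trade-off: the per-layer outputs collected in `state['xs']` for the jumping-knowledge combination stay on the GPU. A minimal sketch of the suggested alternative, assuming `state['xs']` is a plain list of layer outputs (the shapes here are hypothetical):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
state = {'xs': []}

# Stage each layer's output on the CPU instead of keeping it on the GPU:
for _ in range(3):
    x = torch.randn(4, 8, device=device)  # stand-in for a layer's output
    state['xs'].append(x.cpu())           # lower peak GPU memory use

# Move everything back only for the final jumping-knowledge combination:
xs = [h.to(device) for h in state['xs']]
out = torch.cat(xs, dim=-1)
print(out.shape)  # torch.Size([4, 24])
```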
@@ -11,7 +11,7 @@ write_async = torch.ops.torch_geometric_autoscale.write_async
 class AsyncIOPool(torch.nn.Module):
     def __init__(self, pool_size: int, buffer_size: int, embedding_dim: int):
-        super(AsyncIOPool, self).__init__()
+        super().__init__()

         self.pool_size = pool_size
         self.buffer_size = buffer_size
...