Commit 38ca4fb2 authored by rusty1s

update super calls

parent 07932207
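
For context, the change replaces the redundant Python 2-style `super(Class, self)` calls throughout the models with the argument-free Python 3 form; both resolve to the same parent method, but the short form does not repeat the class name and keeps working if the class is renamed. A minimal sketch of the before/after pattern (the `Net` class below is illustrative, not part of the repository):

import torch


class Net(torch.nn.Module):
    def __init__(self, hidden_channels: int):
        # Before this commit: super(Net, self).__init__()
        # After this commit: the zero-argument form, equivalent in Python 3.
        super().__init__()
        self.lin = torch.nn.Linear(hidden_channels, hidden_channels)

The same mechanical substitution is applied to `__init__`, `reset_parameters`, and `_apply` in every file touched below.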
@@ -10,7 +10,6 @@ params:
hidden_channels: 8
hidden_heads: 8
out_heads: 1
residual: false
dropout: 0.6
num_parts: 40
batch_size: 10
@@ -28,7 +27,6 @@ params:
hidden_channels: 8
hidden_heads: 8
out_heads: 1
residual: false
dropout: 0.6
num_parts: 24
batch_size: 8
@@ -46,7 +44,6 @@ params:
hidden_channels: 8
hidden_heads: 8
out_heads: 8
residual: false
dropout: 0.6
num_parts: 4
batch_size: 1
@@ -64,7 +61,6 @@ params:
hidden_channels: 8
hidden_heads: 8
out_heads: 1
residual: false
dropout: 0.6
num_parts: 8
batch_size: 2
@@ -82,7 +78,6 @@ params:
hidden_channels: 8
hidden_heads: 8
out_heads: 1
residual: false
dropout: 0.6
num_parts: 4
batch_size: 1
@@ -100,7 +95,6 @@ params:
hidden_channels: 14
hidden_heads: 5
out_heads: 1
residual: false
dropout: 0.5
num_parts: 2
batch_size: 1
......
@@ -14,8 +14,8 @@ class APPNP(ScalableGNN):
out_channels: int, num_layers: int, alpha: float,
dropout: float = 0.0, pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(APPNP, self).__init__(num_nodes, out_channels, num_layers,
pool_size, buffer_size, device)
super().__init__(num_nodes, out_channels, num_layers, pool_size,
buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
@@ -30,15 +30,11 @@ class APPNP(ScalableGNN):
self.nonreg_modules = self.lins[1:]
def reset_parameters(self):
super(APPNP, self).reset_parameters()
super().reset_parameters()
for lin in self.lins:
lin.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[0](x)
x = x.relu()
@@ -48,7 +44,7 @@ class APPNP(ScalableGNN):
for history in self.histories:
x = (1 - self.alpha) * (adj_t @ x) + self.alpha * x_0
x = self.push_and_pull(history, x, batch_size, n_id, offset, count)
x = self.push_and_pull(history, x, *args)
x = (1 - self.alpha) * (adj_t @ x) + self.alpha * x_0
return x
......
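
Besides the super calls, the hunks above also collapse the explicit `batch_size`, `n_id`, `offset`, and `count` parameters of `forward` into `*args`, which are handed through to `push_and_pull` untouched. A rough sketch of that forwarding pattern, using the parameter list visible in the removed lines (stubbed out here for illustration; this is not the library implementation):

from typing import Optional

import torch
from torch import Tensor


class ForwardSketch(torch.nn.Module):
    def push_and_pull(self, history, x: Tensor,
                      batch_size: Optional[int] = None,
                      n_id: Optional[Tensor] = None,
                      offset: Optional[Tensor] = None,
                      count: Optional[Tensor] = None) -> Tensor:
        # Stub: the real method (roughly) pushes the in-batch rows of `x`
        # into the history buffer and pulls the cached out-of-batch rows.
        return x

    def forward(self, x: Tensor, adj_t, *args) -> Tensor:
        # (batch_size, n_id, offset, count) arrive packed in *args and are
        # forwarded unchanged, so each model no longer has to name them.
        return self.push_and_pull(None, x, *args)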
@@ -14,7 +14,7 @@ class ScalableGNN(torch.nn.Module):
def __init__(self, num_nodes: int, hidden_channels: int, num_layers: int,
pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(ScalableGNN, self).__init__()
super().__init__()
self.num_nodes = num_nodes
self.hidden_channels = hidden_channels
@@ -40,7 +40,7 @@ class ScalableGNN(torch.nn.Module):
return self.histories[0]._device
def _apply(self, fn: Callable) -> None:
super(ScalableGNN, self)._apply(fn)
super()._apply(fn)
# We only initialize the AsyncIOPool in case histories are on CPU:
if (str(self.emb_device) == 'cpu' and str(self.device)[:4] == 'cuda'
and self.pool_size is not None
......
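
The `_apply` override above only sets up asynchronous I/O when the history embeddings stay on the CPU while the model itself has been moved to a CUDA device. A hypothetical restatement of that guard as a standalone predicate (the `buffer_size` clause continues the condition cut off above and is an assumption, as is the function itself):

import torch


def needs_async_pool(emb_device: torch.device, device: torch.device,
                     pool_size, buffer_size) -> bool:
    # Hypothetical helper: asynchronous CPU->GPU pulls only pay off when the
    # histories sit in host memory, the model runs on CUDA, and a pool and
    # buffer size have actually been configured.
    return (str(emb_device) == 'cpu' and str(device)[:4] == 'cuda'
            and pool_size is not None and buffer_size is not None)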
@@ -3,7 +3,7 @@ from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear, ModuleList
from torch.nn import ModuleList
from torch_sparse import SparseTensor
from torch_geometric.nn import GATConv
@@ -13,17 +13,16 @@ from torch_geometric_autoscale.models import ScalableGNN
class GAT(ScalableGNN):
def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
hidden_heads: int, out_channels: int, out_heads: int,
num_layers: int, residual: bool = False, dropout: float = 0.0,
num_layers: int, dropout: float = 0.0,
pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(GAT, self).__init__(num_nodes, hidden_channels * hidden_heads,
num_layers, pool_size, buffer_size, device)
super().__init__(num_nodes, hidden_channels * hidden_heads, num_layers,
pool_size, buffer_size, device)
self.in_channels = in_channels
self.hidden_heads = hidden_heads
self.out_channels = out_channels
self.out_heads = out_heads
self.residual = residual
self.dropout = dropout
self.convs = ModuleList()
@@ -37,62 +36,33 @@ class GAT(ScalableGNN):
concat=False, dropout=dropout, add_self_loops=False)
self.convs.append(conv)
self.lins = ModuleList()
if residual:
self.lins.append(
Linear(in_channels, hidden_channels * hidden_heads))
self.lins.append(
Linear(hidden_channels * hidden_heads, out_channels))
self.reg_modules = ModuleList([self.convs, self.lins])
self.reg_modules = self.convs
self.nonreg_modules = ModuleList()
def reset_parameters(self):
super(GAT, self).reset_parameters()
super().reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for lin in self.lins:
lin.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
for conv, history in zip(self.convs[:-1], self.histories):
h = F.dropout(x, p=self.dropout, training=self.training)
h = conv((h, h[:adj_t.size(0)]), adj_t)
if self.residual:
x = F.dropout(x, p=self.dropout, training=self.training)
h += x if h.size(-1) == x.size(-1) else self.lins[0](x)
x = F.elu(h)
x = self.push_and_pull(history, x, batch_size, n_id, offset, count)
h = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[-1]((h, h[:adj_t.size(0)]), adj_t)
if self.residual:
x = F.dropout(x, p=self.dropout, training=self.training)
h += self.lins[1](x)
return h
x = conv((x, x[:adj_t.size(0)]), adj_t)
x = F.elu(x)
x = self.push_and_pull(history, x, *args)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1]((x, x[:adj_t.size(0)]), adj_t)
return x
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
h = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[layer]((h, h[:adj_t.size(0)]), adj_t)
if layer == 0:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[0](x)
if layer == self.num_layers - 1:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[1](x)
if self.residual:
x = F.dropout(x, p=self.dropout, training=self.training)
h += x
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[layer]((x, x[:adj_t.size(0)]), adj_t)
if layer < self.num_layers - 1:
h = h.elu()
x = x.elu()
return h
return x
@@ -17,8 +17,8 @@ class GCN(ScalableGNN):
residual: bool = False, linear: bool = False,
pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(GCN, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
super().__init__(num_nodes, hidden_channels, num_layers, pool_size,
buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
@@ -60,7 +60,7 @@ class GCN(ScalableGNN):
return self.lins if self.linear else self.convs[-1:]
def reset_parameters(self):
super(GCN, self).reset_parameters()
super().reset_parameters()
for lin in self.lins:
lin.reset_parameters()
for conv in self.convs:
@@ -68,11 +68,7 @@ class GCN(ScalableGNN):
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -87,7 +83,7 @@ class GCN(ScalableGNN):
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
x = h.relu_()
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = self.push_and_pull(hist, x, *args)
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[-1](x, adj_t)
......
@@ -18,8 +18,8 @@ class GCN2(ScalableGNN):
batch_norm: bool = False, residual: bool = False,
pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(GCN2, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
super().__init__(num_nodes, hidden_channels, num_layers, pool_size,
buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
@@ -53,7 +53,7 @@ class GCN2(ScalableGNN):
return self.lins
def reset_parameters(self):
super(GCN2, self).reset_parameters()
super().reset_parameters()
for lin in self.lins:
lin.reset_parameters()
for conv in self.convs:
@@ -61,11 +61,7 @@ class GCN2(ScalableGNN):
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -80,7 +76,7 @@ class GCN2(ScalableGNN):
if self.residual:
h += x[:h.size(0)]
x = h.relu_()
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = self.push_and_pull(hist, x, *args)
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[-1](x, x_0, adj_t)
......
@@ -17,8 +17,7 @@ class PNAConv(MessagePassing):
def __init__(self, in_channels: int, out_channels: int,
aggregators: List[str], scalers: List[str], deg: Tensor,
**kwargs):
super(PNAConv, self).__init__(aggr=None, **kwargs)
super().__init__(aggr=None, **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
@@ -83,8 +82,8 @@ class PNA(ScalableGNN):
drop_input: bool = True, batch_norm: bool = False,
residual: bool = False, pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(PNA, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
super().__init__(num_nodes, hidden_channels, num_layers, pool_size,
buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
@@ -115,17 +114,13 @@ class PNA(ScalableGNN):
return self.convs[-1:]
def reset_parameters(self):
super(PNA, self).reset_parameters()
super().reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -136,7 +131,7 @@ class PNA(ScalableGNN):
if self.residual and h.size(-1) == x.size(-1):
h += x[:h.size(0)]
x = h.relu_()
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = self.push_and_pull(hist, x, *args)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, adj_t)
......
@@ -18,8 +18,8 @@ class PNA_JK(ScalableGNN):
drop_input: bool = True, batch_norm: bool = False,
residual: bool = False, pool_size: Optional[int] = None,
buffer_size: Optional[int] = None, device=None):
super(PNA_JK, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size, buffer_size, device)
super().__init__(num_nodes, hidden_channels, num_layers, pool_size,
buffer_size, device)
self.in_channels = in_channels
self.out_channels = out_channels
@@ -59,7 +59,7 @@ class PNA_JK(ScalableGNN):
return self.lins
def reset_parameters(self):
super(PNA_JK, self).reset_parameters()
super().reset_parameters()
for lin in self.lins:
lin.reset_parameters()
for conv in self.convs:
@@ -67,11 +67,7 @@ class PNA_JK(ScalableGNN):
for bn in self.bns:
bn.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, adj_t: SparseTensor, *args) -> Tensor:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -87,7 +83,7 @@ class PNA_JK(ScalableGNN):
h += x[:h.size(0)]
x = h.relu_()
xs += [x]
x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
x = self.push_and_pull(hist, x, *args)
x = F.dropout(x, p=self.dropout, training=self.training)
h = self.convs[-1](x, adj_t)
@@ -104,6 +100,8 @@ class PNA_JK(ScalableGNN):
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
# We keep the skip connections in GPU memory for now. If one encounters
# GPU memory problems, it is advised to push `state['xs']` to the CPU.
if layer == 0:
if self.drop_input:
x = F.dropout(x, p=self.dropout, training=self.training)
......
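
The comment added to `forward_layer` above notes that the skip connections currently stay in GPU memory and suggests pushing `state['xs']` to the CPU if memory becomes tight. A hypothetical helper for that suggestion (the function name and its call site are assumptions, not part of the diff):

from typing import Dict, List

from torch import Tensor


def offload_skip_connections(state: Dict[str, List[Tensor]]) -> None:
    # Hypothetical: move the cached per-layer outputs (`state['xs']`) to host
    # memory so they stop occupying GPU memory during layer-wise inference.
    if 'xs' in state:
        state['xs'] = [x.cpu() for x in state['xs']]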
@@ -11,7 +11,7 @@ write_async = torch.ops.torch_geometric_autoscale.write_async
class AsyncIOPool(torch.nn.Module):
def __init__(self, pool_size: int, buffer_size: int, embedding_dim: int):
super(AsyncIOPool, self).__init__()
super().__init__()
self.pool_size = pool_size
self.buffer_size = buffer_size
......