Commit d3975fdc authored by rusty1s

gat model

parent c0aaaedd
torch_geometric_autoscale/models/__init__.py

@@ -1,6 +1,6 @@
 from .base import ScalableGNN
 from .gcn import GCN
-# from .gat import GAT
+from .gat import GAT
 # from .appnp import APPNP
 # from .gcn2 import GCN2
 # from .pna import PNA
@@ -9,7 +9,7 @@ from .gcn import GCN
 __all__ = [
     'ScalableGNN',
     'GCN',
-    # 'GAT',
+    'GAT',
     # 'APPNP',
     # 'GCN2',
     # 'PNA',
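With the export uncommented, GAT becomes part of the package's public interface alongside GCN and ScalableGNN. A minimal import sketch, assuming the package path used in gat.py below:

from torch_geometric_autoscale.models import GAT, GCN, ScalableGNN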
torch_geometric_autoscale/models/gat.py

@@ -7,16 +7,17 @@ from torch.nn import Linear, ModuleList
 from torch_sparse import SparseTensor
 from torch_geometric.nn import GATConv

-from .base import HistoryGNN
+from torch_geometric_autoscale.models import ScalableGNN


-class GAT(HistoryGNN):
+class GAT(ScalableGNN):
     def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
                  hidden_heads: int, out_channels: int, out_heads: int,
                  num_layers: int, residual: bool = False, dropout: float = 0.0,
-                 device=None, dtype=None):
+                 pool_size: Optional[int] = None,
+                 buffer_size: Optional[int] = None, device=None):
         super(GAT, self).__init__(num_nodes, hidden_channels * hidden_heads,
-                                  num_layers, device, dtype)
+                                  num_layers, pool_size, buffer_size, device)

         self.in_channels = in_channels
         self.hidden_heads = hidden_heads
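For reference, a minimal construction sketch against the new signature. All hyperparameter values below are illustrative (Cora-sized), not taken from the commit, and pool_size/buffer_size are assumed to configure the history buffers managed by the ScalableGNN base class:

from torch_geometric_autoscale.models import GAT

# Illustrative values only; pool_size/buffer_size presumably control
# the pinned history buffers inherited from ScalableGNN.
model = GAT(num_nodes=2708, in_channels=1433, hidden_channels=8,
            hidden_heads=8, out_channels=7, out_heads=1, num_layers=2,
            residual=True, dropout=0.6, pool_size=2, buffer_size=5000)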
@@ -55,53 +56,43 @@ class GAT(HistoryGNN):
     def forward(self, x: Tensor, adj_t: SparseTensor,
                 batch_size: Optional[int] = None,
-                n_id: Optional[Tensor] = None) -> Tensor:
+                n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
+                count: Optional[Tensor] = None) -> Tensor:
         for conv, history in zip(self.convs[:-1], self.histories):
             h = F.dropout(x, p=self.dropout, training=self.training)
-            h = conv(h, adj_t)
+            h = conv((h, h[:adj_t.size(0)]), adj_t)
             if self.residual:
                 x = F.dropout(x, p=self.dropout, training=self.training)
-                h = h + x if h.size(-1) == x.size(-1) else h + self.lins[0](x)
+                h += x if h.size(-1) == x.size(-1) else self.lins[0](x)
             x = F.elu(h)
-            x = self.push_and_pull(history, x, batch_size, n_id)
+            x = self.push_and_pull(history, x, batch_size, n_id, offset, count)

         h = F.dropout(x, p=self.dropout, training=self.training)
-        h = self.convs[-1](h, adj_t)
+        h = self.convs[-1]((h, h[:adj_t.size(0)]), adj_t)
         if self.residual:
             x = F.dropout(x, p=self.dropout, training=self.training)
-            h = h + self.lins[1](x)
-        if batch_size is not None:
-            h = h[:batch_size]
+            h += self.lins[1](x)

         return h
-    @torch.no_grad()
-    def mini_inference(self, x: Tensor, loader) -> Tensor:
-        for conv, history in zip(self.convs[:-1], self.histories):
-            for info in loader:
-                info = info.to(self.device)
-                batch_size, n_id, adj_t, e_id = info
-                r = x[n_id]
-                h = conv(r, adj_t)
-                if self.residual:
-                    if h.size(-1) == r.size(-1):
-                        h = h + r
-                    else:
-                        h = h + self.lins[0](r)
-                h = F.elu(h)
-                history.push_(h[:batch_size], n_id[:batch_size])
-            x = history.pull()
-
-        out = x.new_empty(self.num_nodes, self.out_channels)
-        for info in loader:
-            info = info.to(self.device)
-            batch_size, n_id, adj_t, e_id = info
-            r = x[n_id]
-            h = self.convs[-1](r, adj_t)[:batch_size]
-            if self.residual:
-                h = h + self.lins[1](r)
-            out[n_id[:batch_size]] = h
-        return out
+    def forward_layer(self, layer, x, adj_t, state):
+        h = F.dropout(x, p=self.dropout, training=self.training)
+        h = self.convs[layer]((h, h[:adj_t.size(0)]), adj_t)
+        if layer == 0:
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = self.lins[0](x)
+        if layer == self.num_layers - 1:
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = self.lins[1](x)
+        if self.residual:
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            h += x
+        if layer < self.num_layers - 1:
+            h = h.elu()
+        return h
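Both forward and forward_layer now call GATConv with a (source, target) feature pair, where the first adj_t.size(0) rows of the feature matrix are the target nodes. A minimal standalone sketch of this call convention; shapes and values are illustrative, not from the commit:

import torch
from torch_sparse import SparseTensor
from torch_geometric.nn import GATConv

# adj_t has shape [num_target, num_source]; the target nodes come
# first in the feature matrix, as in sampled mini-batches.
adj_t = SparseTensor(row=torch.tensor([0, 0, 1]),
                     col=torch.tensor([0, 2, 3]),
                     sparse_sizes=(2, 4))
h = torch.randn(4, 16)
conv = GATConv(16, 8, heads=2)

# Bipartite call: attention and output are computed only for the
# two target nodes.
out = conv((h, h[:adj_t.size(0)]), adj_t)
print(out.shape)  # torch.Size([2, 16])

Replacing the model-specific mini_inference with a per-layer forward_layer presumably lets the ScalableGNN base class drive history-backed, layer-by-layer inference generically; the base class itself is not part of this diff.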
torch_geometric_autoscale/models/gcn.py

@@ -28,8 +28,8 @@ class GCN(ScalableGNN):
         self.residual = residual
         self.linear = linear

+        self.lins = ModuleList()
         if linear:
-            self.lins = ModuleList()
             self.lins.append(Linear(in_channels, hidden_channels))
             self.lins.append(Linear(hidden_channels, out_channels))
@@ -61,9 +61,8 @@ class GCN(ScalableGNN):
     def reset_parameters(self):
         super(GCN, self).reset_parameters()
-        if self.linear:
-            for lin in self.lins:
-                lin.reset_parameters()
+        for lin in self.lins:
+            lin.reset_parameters()
         for conv in self.convs:
             conv.reset_parameters()
         for bn in self.bns:
             bn.reset_parameters()
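The reset_parameters simplification works because self.lins is now always a ModuleList, merely empty when linear=False, so the unguarded loop is a no-op. A standalone sketch of the behavior relied on:

from torch.nn import Linear, ModuleList

lins = ModuleList()            # linear=False: list stays empty
for lin in lins:               # iterates zero times; no guard needed
    lin.reset_parameters()

lins.append(Linear(4, 8))      # linear=True: layers are registered
for lin in lins:
    lin.reset_parameters()     # resets the appended Linear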