Commit d3975fdc authored by rusty1s

gat model

parent c0aaaedd
 from .base import ScalableGNN
 from .gcn import GCN
-# from .gat import GAT
+from .gat import GAT
 # from .appnp import APPNP
 # from .gcn2 import GCN2
 # from .pna import PNA
@@ -9,7 +9,7 @@ from .gcn import GCN
 __all__ = [
     'ScalableGNN',
     'GCN',
-    # 'GAT',
+    'GAT',
     # 'APPNP',
     # 'GCN2',
     # 'PNA',
...
@@ -7,16 +7,17 @@ from torch.nn import Linear, ModuleList
 from torch_sparse import SparseTensor
 from torch_geometric.nn import GATConv
 
-from .base import HistoryGNN
+from torch_geometric_autoscale.models import ScalableGNN
 
 
-class GAT(HistoryGNN):
+class GAT(ScalableGNN):
     def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
                  hidden_heads: int, out_channels: int, out_heads: int,
                  num_layers: int, residual: bool = False, dropout: float = 0.0,
-                 device=None, dtype=None):
+                 pool_size: Optional[int] = None,
+                 buffer_size: Optional[int] = None, device=None):
         super(GAT, self).__init__(num_nodes, hidden_channels * hidden_heads,
-                                  num_layers, device, dtype)
+                                  num_layers, pool_size, buffer_size, device)
 
         self.in_channels = in_channels
         self.hidden_heads = hidden_heads
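
To make the new constructor signature concrete, here is a hedged instantiation sketch. Every numeric value below is an illustrative placeholder, and the comments only describe how the arguments are forwarded to the ScalableGNN base class, not behavior specific to this commit.

from torch_geometric_autoscale.models import GAT

model = GAT(
    num_nodes=100_000,      # total number of nodes covered by the histories (placeholder)
    in_channels=128,        # input feature dimensionality (placeholder)
    hidden_channels=256,
    hidden_heads=4,         # histories are sized hidden_channels * hidden_heads
    out_channels=40,
    out_heads=1,
    num_layers=3,
    residual=True,
    dropout=0.5,
    pool_size=2,            # forwarded to ScalableGNN together with buffer_size
    buffer_size=10_000,     # both control the base class's history transfer buffers
)
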
@@ -55,53 +56,43 @@ class GAT(HistoryGNN):
     def forward(self, x: Tensor, adj_t: SparseTensor,
                 batch_size: Optional[int] = None,
-                n_id: Optional[Tensor] = None) -> Tensor:
+                n_id: Optional[Tensor] = None, offset: Optional[Tensor] = None,
+                count: Optional[Tensor] = None) -> Tensor:
 
         for conv, history in zip(self.convs[:-1], self.histories):
             h = F.dropout(x, p=self.dropout, training=self.training)
-            h = conv(h, adj_t)
+            h = conv((h, h[:adj_t.size(0)]), adj_t)
             if self.residual:
                 x = F.dropout(x, p=self.dropout, training=self.training)
-                h = h + x if h.size(-1) == x.size(-1) else h + self.lins[0](x)
+                h += x if h.size(-1) == x.size(-1) else self.lins[0](x)
             x = F.elu(h)
-            x = self.push_and_pull(history, x, batch_size, n_id)
+            x = self.push_and_pull(history, x, batch_size, n_id, offset, count)
 
         h = F.dropout(x, p=self.dropout, training=self.training)
-        h = self.convs[-1](h, adj_t)
+        h = self.convs[-1]((h, h[:adj_t.size(0)]), adj_t)
         if self.residual:
             x = F.dropout(x, p=self.dropout, training=self.training)
-            h = h + self.lins[1](x)
-
-        if batch_size is not None:
-            h = h[:batch_size]
+            h += self.lins[1](x)
 
         return h
     @torch.no_grad()
-    def mini_inference(self, x: Tensor, loader) -> Tensor:
-        for conv, history in zip(self.convs[:-1], self.histories):
-            for info in loader:
-                info = info.to(self.device)
-                batch_size, n_id, adj_t, e_id = info
-                r = x[n_id]
-                h = conv(r, adj_t)
-                if self.residual:
-                    if h.size(-1) == r.size(-1):
-                        h = h + r
-                    else:
-                        h = h + self.lins[0](r)
-                h = F.elu(h)
-                history.push_(h[:batch_size], n_id[:batch_size])
-            x = history.pull()
-
-        out = x.new_empty(self.num_nodes, self.out_channels)
-        for info in loader:
-            info = info.to(self.device)
-            batch_size, n_id, adj_t, e_id = info
-            r = x[n_id]
-            h = self.convs[-1](r, adj_t)[:batch_size]
-            if self.residual:
-                h = h + self.lins[1](r)
-            out[n_id[:batch_size]] = h
-
-        return out
+    def forward_layer(self, layer, x, adj_t, state):
+        h = F.dropout(x, p=self.dropout, training=self.training)
+        h = self.convs[layer]((h, h[:adj_t.size(0)]), adj_t)
+
+        if layer == 0:
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = self.lins[0](x)
+        if layer == self.num_layers - 1:
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = self.lins[1](x)
+
+        if self.residual:
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            h += x
+
+        if layer < self.num_layers - 1:
+            h = F.elu(h)
+
+        return h
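
In the reworked forward pass each GATConv is called with a bipartite input pair (h, h[:adj_t.size(0)]), so only the first adj_t.size(0) rows act as target nodes, and push_and_pull additionally receives offset and count to address the history rows of out-of-mini-batch nodes; per-layer inference now goes through forward_layer (presumably driven by the ScalableGNN base class) instead of the hand-rolled mini_inference. Below is a hedged sketch of a mini-batch training step against the new forward signature; the batch fields used (x, adj_t, batch_size, n_id, offset, count) are assumptions about the surrounding training code and are not defined in this commit.

import torch.nn.functional as F

def train_step(model, optimizer, batch, y):
    # Assumed mini-batch fields: batch.x, batch.adj_t, batch.batch_size,
    # batch.n_id, batch.offset, batch.count (not part of this commit).
    model.train()
    optimizer.zero_grad()
    out = model(batch.x, batch.adj_t, batch.batch_size,
                batch.n_id, batch.offset, batch.count)
    # Only the first batch_size rows correspond to in-mini-batch target nodes.
    loss = F.cross_entropy(out[:batch.batch_size], y[:batch.batch_size])
    loss.backward()
    optimizer.step()
    return float(loss)
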
@@ -28,8 +28,8 @@ class GCN(ScalableGNN):
         self.residual = residual
         self.linear = linear
 
-        if linear:
-            self.lins = ModuleList()
+        self.lins = ModuleList()
+        if linear:
             self.lins.append(Linear(in_channels, hidden_channels))
             self.lins.append(Linear(hidden_channels, out_channels))
@@ -61,7 +61,6 @@ class GCN(ScalableGNN):
     def reset_parameters(self):
         super(GCN, self).reset_parameters()
-        if self.linear:
-            for lin in self.lins:
-                lin.reset_parameters()
+        for lin in self.lins:
+            lin.reset_parameters()
         for conv in self.convs:
...
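
The GCN tweak makes self.lins an always-present (possibly empty) ModuleList, so reset_parameters can drop its if self.linear: guard: iterating an empty ModuleList is simply a no-op. A small standalone sketch of that pattern follows; the Example module is illustrative only and not part of the repository.

from torch.nn import Linear, Module, ModuleList

class Example(Module):
    def __init__(self, linear: bool):
        super().__init__()
        self.lins = ModuleList()            # always exists, possibly empty
        if linear:
            self.lins.append(Linear(8, 8))
            self.lins.append(Linear(8, 8))

    def reset_parameters(self):
        for lin in self.lins:               # empty list -> loop body never runs
            lin.reset_parameters()

Example(linear=False).reset_parameters()    # fine: nothing to reset
Example(linear=True).reset_parameters()     # resets both Linear layers
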