Commit a4cbb359 authored by rusty1s

update model

parent 74b1b814
__init__.py
-from .base import HistoryGNN
+from .base import ScalableGNN
 from .gcn import GCN
+from .sage import SAGE
 from .gat import GAT
 from .appnp import APPNP
 from .gcn2 import GCN2
 from .gin import GIN
+from .transformer import Transformer
 from .pna import PNA
 from .pna_jk import PNA_JK
 
 __all__ = [
-    'HistoryGNN',
+    'ScalableGNN',
     'GCN',
+    'SAGE',
     'GAT',
     'APPNP',
     'GCN2',
     'GIN',
+    'Transformer',
     'PNA',
     'PNA_JK',
 ]
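For downstream code, the rename changes the public import surface. A minimal sketch, assuming these modules sit in a `models` subpackage of `torch_geometric_autoscale` (the package path itself is not visible in this diff):

# Hypothetical import path; only the module file names appear in the diff.
from torch_geometric_autoscale.models import ScalableGNN  # was: HistoryGNN
from torch_geometric_autoscale.models import SAGE, Transformer  # newly exported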
base.py
-from typing import Optional, Callable
+from typing import Optional, Callable, Dict, Any
 import warnings
@@ -6,8 +6,7 @@
 import torch
 from torch import Tensor
 from torch_sparse import SparseTensor
-from scaling_gnns.history2 import History
-from scaling_gnns.pool import AsyncIOPool
+from torch_geometric_autoscale import History, AsyncIOPool, SubgraphLoader
 
 
 class ScalableGNN(torch.nn.Module):
@@ -125,7 +124,7 @@ class ScalableGNN(torch.nn.Module):
         return out
 
     @torch.no_grad()
-    def mini_inference(self, loader) -> Tensor:
+    def mini_inference(self, loader: SubgraphLoader) -> Tensor:
         loader = [data + ({}, ) for data in loader]
         for batch, batch_size, n_id, offset, count, state in loader:
@@ -162,3 +161,8 @@ class ScalableGNN(torch.nn.Module):
         self.pool.synchronize_push()
         return self._out
+
+    @torch.no_grad()
+    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
+                      state: Dict[str, Any]) -> Tensor:
+        raise NotImplementedError
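The new forward_layer hook is the per-layer entry point that mini_inference drives once per GNN layer. A minimal sketch of a subclass overriding it follows; the constructor, the `convs` module list, and the activation choice are illustrative assumptions, not part of this commit:

from typing import Any, Dict

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_sparse import SparseTensor


class TwoLayerGNN(ScalableGNN):  # ScalableGNN as defined in base.py above
    # __init__ omitted: assume it builds `self.convs`, a torch.nn.ModuleList
    # of message-passing layers, and calls super().__init__(...) with
    # whatever arguments ScalableGNN expects (not shown in this diff).

    @torch.no_grad()
    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
                      state: Dict[str, Any]) -> Tensor:
        # One propagation step for layer `layer`. `state` is the empty
        # per-batch dict that mini_inference injects via `data + ({}, )`;
        # subclasses may use it to cache values across layers of a batch.
        x = self.convs[layer](x, adj_t)
        if layer < len(self.convs) - 1:  # no nonlinearity on the last layer
            x = F.relu(x)
        return x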
gcn.py
-from typing import Optional
+from typing import Optional, Dict, Any
 import torch
 from torch import Tensor
@@ -75,7 +75,8 @@ class GCN(ScalableGNN):
         return x
 
     @torch.no_grad()
-    def forward_layer(self, layer, x, adj_t, state):
+    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
+                      state: Dict[str, Any]) -> Tensor:
         if layer == 0 and self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)
         else:
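At inference time, the newly typed mini_inference is driven by a SubgraphLoader. The snippet below is a hedged usage sketch: the loader construction is not part of this commit, and only the shape of the yielded tuples can be read off the loop in base.py above.

# `model` is any ScalableGNN subclass implementing forward_layer; `loader`
# is a SubgraphLoader whose construction is not shown in this diff. Each
# element it yields must be a 5-tuple (batch, batch_size, n_id, offset,
# count); mini_inference appends an empty per-batch state dict before
# unpacking six values per batch.
model.eval()
out = model.mini_inference(loader)  # layer-by-layer full-graph inference
pred = out.argmax(dim=-1)           # e.g., node-classification predictions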