"torchvision/vscode:/vscode.git/clone" did not exist on "43dbfd2e397930a9e4595a8914eb0221b34a55d5"
Commit a4cbb359 authored by rusty1s

update model

parent 74b1b814
-from .base import HistoryGNN
+from .base import ScalableGNN
 from .gcn import GCN
 from .sage import SAGE
 from .gat import GAT
 from .appnp import APPNP
 from .gcn2 import GCN2
 from .gin import GIN
 from .transformer import Transformer
 from .pna import PNA
 from .pna_jk import PNA_JK

 __all__ = [
-    'HistoryGNN',
+    'ScalableGNN',
     'GCN',
     'SAGE',
     'GAT',
     'APPNP',
     'GCN2',
     'GIN',
     'Transformer',
     'PNA',
     'PNA_JK',
 ]
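For orientation, a minimal import sketch against the renamed base class. The module path torch_geometric_autoscale.models is an assumption inferred from the package import further down in this commit; it is not shown explicitly here.

# Minimal sketch, assuming these re-exports live in a `models` subpackage of
# torch_geometric_autoscale (the path is inferred, not shown in this diff).
from torch_geometric_autoscale.models import ScalableGNN, GCN

# Every concrete model listed in __all__ is expected to subclass the renamed
# ScalableGNN base class (formerly HistoryGNN); the GCN hunk below shows
# `class GCN(ScalableGNN)` directly.
assert issubclass(GCN, ScalableGNN)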
-from typing import Optional, Callable
+from typing import Optional, Callable, Dict, Any
 import warnings

@@ -6,8 +6,7 @@ import torch
 from torch import Tensor
 from torch_sparse import SparseTensor

-from scaling_gnns.history2 import History
-from scaling_gnns.pool import AsyncIOPool
+from torch_geometric_autoscale import History, AsyncIOPool, SubgraphLoader


 class ScalableGNN(torch.nn.Module):
@@ -125,7 +124,7 @@ class ScalableGNN(torch.nn.Module):
         return out

     @torch.no_grad()
-    def mini_inference(self, loader) -> Tensor:
+    def mini_inference(self, loader: SubgraphLoader) -> Tensor:
         loader = [data + ({}, ) for data in loader]

         for batch, batch_size, n_id, offset, count, state in loader:
@@ -162,3 +161,8 @@ class ScalableGNN(torch.nn.Module):
         self.pool.synchronize_push()

         return self._out
+
+    @torch.no_grad()
+    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
+                      state: Dict[str, Any]) -> Tensor:
+        raise NotImplementedError
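The newly added forward_layer stub defines a per-layer inference hook: mini_inference attaches an empty state dict to every batch (the `({}, )` above) and subclasses are expected to compute one layer at a time. Below is a rough subclass sketch of that contract; the attribute names self.convs and self.num_layers are assumptions, not taken from this commit.

# Hypothetical subclass sketch of the forward_layer contract; self.convs and
# self.num_layers are assumed attributes, not part of this commit.
from typing import Any, Dict

import torch
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric_autoscale.models import ScalableGNN  # assumed path


class MyModel(ScalableGNN):
    @torch.no_grad()
    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
                      state: Dict[str, Any]) -> Tensor:
        # `state` is the per-batch dict created in mini_inference; it persists
        # across the layer-wise calls for one batch and can cache intermediates.
        x = self.convs[layer](x, adj_t)      # one message-passing step
        if layer < self.num_layers - 1:      # activation on hidden layers only
            x = x.relu_()
        return x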
-from typing import Optional
+from typing import Optional, Dict, Any

 import torch
 from torch import Tensor
@@ -75,7 +75,8 @@ class GCN(ScalableGNN):
         return x

     @torch.no_grad()
-    def forward_layer(self, layer, x, adj_t, state):
+    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
+                      state: Dict[str, Any]) -> Tensor:
         if layer == 0 and self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)
         else:
......
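The GCN hunk is truncated above. As a usage note, mini_inference appears to be the evaluation-time entry point: it walks the SubgraphLoader batches (see the loop earlier) and returns the tensor stored in self._out, presumably built from the per-layer forward_layer calls. Here is a hedged sketch of calling it; the surrounding evaluation code is assumed and not part of this commit.

# Hypothetical evaluation helper, not part of this commit: `model` is any
# ScalableGNN subclass and `loader` a SubgraphLoader built elsewhere.
import torch
from torch import Tensor


@torch.no_grad()
def predict(model, loader) -> Tensor:
    model.eval()
    out = model.mini_inference(loader)   # full-graph outputs via forward_layer
    return out.argmax(dim=-1)            # e.g. per-node class predictions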