Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
pyg_autoscale
Commits
a4cbb359
"torchvision/vscode:/vscode.git/clone" did not exist on "43dbfd2e397930a9e4595a8914eb0221b34a55d5"
Commit
a4cbb359
authored
Feb 02, 2021
by
rusty1s
Browse files
update model
parent
74b1b814
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
12 deletions
+13
-12
torch_geometric_autoscale/models/__init__.py
torch_geometric_autoscale/models/__init__.py
+2
-6
torch_geometric_autoscale/models/base.py
torch_geometric_autoscale/models/base.py
+8
-4
torch_geometric_autoscale/models/gcn.py
torch_geometric_autoscale/models/gcn.py
+3
-2
No files found.
torch_geometric_autoscale/models/__init__.py
View file @
a4cbb359
from
.base
import
History
GNN
from
.base
import
Scalable
GNN
from
.gcn
import
GCN
from
.sage
import
SAGE
from
.gat
import
GAT
from
.appnp
import
APPNP
from
.gcn2
import
GCN2
from
.gin
import
GIN
from
.transformer
import
Transformer
from
.pna
import
PNA
from
.pna_jk
import
PNA_JK
__all__
=
[
'
History
GNN'
,
'
Scalable
GNN'
,
'GCN'
,
'SAGE'
,
'GAT'
,
'APPNP'
,
'GCN2'
,
'GIN'
,
'Transformer'
,
'PNA'
,
'PNA_JK'
,
]
torch_geometric_autoscale/models/base.py
View file @
a4cbb359
from
typing
import
Optional
,
Callable
from
typing
import
Optional
,
Callable
,
Dict
,
Any
import
warnings
...
...
@@ -6,8 +6,7 @@ import torch
from
torch
import
Tensor
from
torch_sparse
import
SparseTensor
from
scaling_gnns.history2
import
History
from
scaling_gnns.pool
import
AsyncIOPool
from
torch_geometric_autoscale
import
History
,
AsyncIOPool
,
SubgraphLoader
class
ScalableGNN
(
torch
.
nn
.
Module
):
...
...
@@ -125,7 +124,7 @@ class ScalableGNN(torch.nn.Module):
return
out
@
torch
.
no_grad
()
def
mini_inference
(
self
,
loader
)
->
Tensor
:
def
mini_inference
(
self
,
loader
:
SubgraphLoader
)
->
Tensor
:
loader
=
[
data
+
({},
)
for
data
in
loader
]
for
batch
,
batch_size
,
n_id
,
offset
,
count
,
state
in
loader
:
...
...
@@ -162,3 +161,8 @@ class ScalableGNN(torch.nn.Module):
self
.
pool
.
synchronize_push
()
return
self
.
_out
@
torch
.
no_grad
()
def
forward_layer
(
self
,
layer
:
int
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
state
:
Dict
[
Any
])
->
Tensor
:
raise
NotImplementedError
torch_geometric_autoscale/models/gcn.py
View file @
a4cbb359
from
typing
import
Optional
from
typing
import
Optional
,
Dict
,
Any
import
torch
from
torch
import
Tensor
...
...
@@ -75,7 +75,8 @@ class GCN(ScalableGNN):
return
x
@
torch
.
no_grad
()
def
forward_layer
(
self
,
layer
,
x
,
adj_t
,
state
):
def
forward_layer
(
self
,
layer
:
int
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
state
:
Dict
[
Any
])
->
Tensor
:
if
layer
==
0
and
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment