Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
pyg_autoscale
Commits
d3975fdc
Commit
d3975fdc
authored
Feb 04, 2021
by
rusty1s
Browse files
gat model
parent
c0aaaedd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
46 deletions
+36
-46
torch_geometric_autoscale/models/__init__.py
torch_geometric_autoscale/models/__init__.py
+2
-2
torch_geometric_autoscale/models/gat.py
torch_geometric_autoscale/models/gat.py
+31
-40
torch_geometric_autoscale/models/gcn.py
torch_geometric_autoscale/models/gcn.py
+3
-4
No files found.
torch_geometric_autoscale/models/__init__.py
View file @
d3975fdc
from
.base
import
ScalableGNN
from
.gcn
import
GCN
#
from .gat import GAT
from
.gat
import
GAT
# from .appnp import APPNP
# from .gcn2 import GCN2
# from .pna import PNA
...
...
@@ -9,7 +9,7 @@ from .gcn import GCN
__all__
=
[
'ScalableGNN'
,
'GCN'
,
#
'GAT',
'GAT'
,
# 'APPNP',
# 'GCN2',
# 'PNA',
...
...
torch_geometric_autoscale/models/gat.py
View file @
d3975fdc
...
...
@@ -7,16 +7,17 @@ from torch.nn import Linear, ModuleList
from torch_sparse import SparseTensor
from torch_geometric.nn import GATConv

from torch_geometric_autoscale.models import ScalableGNN


class GAT(ScalableGNN):
    """GAT rebased onto ScalableGNN (was HistoryGNN before this commit).

    The super().__init__ call registers per-layer histories of width
    ``hidden_channels * hidden_heads`` (GATConv concatenates head outputs),
    and now threads through ``pool_size``/``buffer_size`` instead of the old
    ``device``/``dtype`` pair.
    """

    def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
                 hidden_heads: int, out_channels: int, out_heads: int,
                 num_layers: int, residual: bool = False, dropout: float = 0.0,
                 pool_size: Optional[int] = None,
                 buffer_size: Optional[int] = None, device=None):
        # History embedding size is the concatenated multi-head width.
        super(GAT, self).__init__(num_nodes, hidden_channels * hidden_heads,
                                  num_layers, pool_size, buffer_size, device)

        self.in_channels = in_channels
        self.hidden_heads = hidden_heads
        # NOTE(review): the diff hunk is truncated here — the remainder of
        # __init__ (out_channels/out_heads storage, conv/lin construction,
        # dropout/residual attributes) is not shown in this view.
...
...
@@ -55,53 +56,43 @@ class GAT(HistoryGNN):
# Method of GAT (post-commit version reconstructed from the diff).
def forward(self, x: Tensor, adj_t: SparseTensor,
            batch_size: Optional[int] = None,
            n_id: Optional[Tensor] = None,
            offset: Optional[Tensor] = None,
            count: Optional[Tensor] = None) -> Tensor:
    """Forward pass with history push/pull after every hidden layer.

    This commit changes the conv call to the bipartite form
    ``conv((h, h[:adj_t.size(0)]), adj_t)`` — target nodes are the first
    ``adj_t.size(0)`` rows — and extends ``push_and_pull`` with the new
    ``offset``/``count`` arguments.
    """
    for conv, history in zip(self.convs[:-1], self.histories):
        h = F.dropout(x, p=self.dropout, training=self.training)
        # Bipartite message passing: (source features, target features).
        h = conv((h, h[:adj_t.size(0)]), adj_t)
        if self.residual:
            x = F.dropout(x, p=self.dropout, training=self.training)
            # Project the skip connection only when widths disagree.
            h += x if h.size(-1) == x.size(-1) else self.lins[0](x)
        x = F.elu(h)
        x = self.push_and_pull(history, x, batch_size, n_id, offset, count)

    h = F.dropout(x, p=self.dropout, training=self.training)
    h = self.convs[-1]((h, h[:adj_t.size(0)]), adj_t)
    if self.residual:
        x = F.dropout(x, p=self.dropout, training=self.training)
        h += self.lins[1](x)
    # NOTE(review): the old `if batch_size is not None: h = h[:batch_size]`
    # tail was removed in this commit — callers now receive full-size output.
    return h
# Method of GAT, REMOVED by this commit (kept here as the pre-commit text;
# its trailing `return out` appears further down in the diff residue).
@torch.no_grad()
def mini_inference(self, x: Tensor, loader) -> Tensor:
    """Layer-wise full-graph inference over mini-batches from `loader`.

    Each hidden layer is evaluated for every batch and cached in its
    history before the next layer runs; the final layer writes into `out`.
    """
    for conv, history in zip(self.convs[:-1], self.histories):
        for info in loader:
            info = info.to(self.device)
            batch_size, n_id, adj_t, e_id = info
            r = x[n_id]
            h = conv(r, adj_t)
            if self.residual:
                if h.size(-1) == r.size(-1):
                    h = h + r
                else:
                    h = h + self.lins[0](r)
            h = F.elu(h)
            # Only the in-batch target nodes are pushed to the history.
            history.push_(h[:batch_size], n_id[:batch_size])
        x = history.pull()

    out = x.new_empty(self.num_nodes, self.out_channels)
    for info in loader:
        info = info.to(self.device)
        batch_size, n_id, adj_t, e_id = info
        r = x[n_id]
        h = self.convs[-1](r, adj_t)[:batch_size]
        if self.residual:
            # NOTE(review): `h` is sliced to batch_size but `r` is not —
            # shape mismatch unless lins[1] broadcasts; likely a latent bug
            # in this removed code. Confirm against the repository history.
            h = h + self.lins[1](r)
        out[n_id[:batch_size]] = h
    return out
# Method of GAT, ADDED by this commit (replaces mini_inference).
def forward_layer(self, layer, x, adj_t, state):
    """Evaluate a single GAT layer for layer-wise inference.

    NOTE(review): reconstructed from a token-shredded diff; the nesting of
    the conditionals below is ambiguous in the rendering. Flat, single-level
    guards match the token order, but confirm against the repository.
    `state` is accepted for interface compatibility and unused here.
    """
    h = F.dropout(x, p=self.dropout, training=self.training)
    # Bipartite call, consistent with forward(): targets first adj_t.size(0).
    h = self.convs[layer]((h, h[:adj_t.size(0)]), adj_t)
    if layer == 0:
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[0](x)
    if layer == self.num_layers - 1:
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[1](x)
    if self.residual:
        x = F.dropout(x, p=self.dropout, training=self.training)
        h += x
    if layer < self.num_layers - 1:
        h = h.elu()
    return h
torch_geometric_autoscale/models/gcn.py
View file @
d3975fdc
...
...
@@ -28,8 +28,8 @@ class GCN(ScalableGNN):
self
.
residual
=
residual
self
.
linear
=
linear
self
.
lins
=
ModuleList
()
if
linear
:
self
.
lins
=
ModuleList
()
self
.
lins
.
append
(
Linear
(
in_channels
,
hidden_channels
))
self
.
lins
.
append
(
Linear
(
hidden_channels
,
out_channels
))
...
...
@@ -61,9 +61,8 @@ class GCN(ScalableGNN):
# Method of GCN (post-commit version reconstructed from the diff).
def reset_parameters(self):
    """Re-initialize all learnable sub-modules.

    This commit drops the old ``if self.linear:`` guard around the `lins`
    loop — `lins` now always exist (the companion __init__ hunk creates the
    ModuleList unconditionally).
    """
    super(GCN, self).reset_parameters()
    for lin in self.lins:
        lin.reset_parameters()
    for conv in self.convs:
        conv.reset_parameters()
    for bn in self.bns:
        # NOTE(review): the diff hunk is truncated at this loop; presumably
        # `bn.reset_parameters()` follows — confirm against the repository.
        bn.reset_parameters()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment