Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
pyg_autoscale
Commits
c427300e
Commit
c427300e
authored
Jun 07, 2021
by
rusty1s
Browse files
clean up
parent
d8d31882
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
24 deletions
+39
-24
README.md
README.md
+13
-3
torch_geometric_autoscale/data.py
torch_geometric_autoscale/data.py
+12
-11
torch_geometric_autoscale/metis.py
torch_geometric_autoscale/metis.py
+2
-2
torch_geometric_autoscale/models/base.py
torch_geometric_autoscale/models/base.py
+12
-8
No files found.
README.md
View file @
c427300e
...
@@ -21,7 +21,8 @@ from torch_geometric_autoscale import ScalableGNN
...
@@ -21,7 +21,8 @@ from torch_geometric_autoscale import ScalableGNN
class
GNN
(
ScalableGNN
):
class
GNN
(
ScalableGNN
):
def
__init__
(
self
,
num_nodes
,
in_channels
,
hidden_channels
,
out_channels
,
num_layers
):
def
__init__
(
self
,
num_nodes
,
in_channels
,
hidden_channels
,
out_channels
,
num_layers
):
super
(
GNN
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
)
super
(
GNN
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
pool_size
=
2
,
buffer_size
=
5000
)
self
.
convs
=
ModuleList
()
self
.
convs
=
ModuleList
()
self
.
convs
.
append
(
GCNConv
(
in_channels
,
hidden_channels
))
self
.
convs
.
append
(
GCNConv
(
in_channels
,
hidden_channels
))
...
@@ -29,11 +30,20 @@ class GNN(ScalableGNN):
...
@@ -29,11 +30,20 @@ class GNN(ScalableGNN):
self
.
convs
.
append
(
GCNConv
(
hidden_channels
,
hidden_channels
))
self
.
convs
.
append
(
GCNConv
(
hidden_channels
,
hidden_channels
))
self
.
convs
.
append
(
GCNConv
(
hidden_channels
,
out_channels
))
self
.
convs
.
append
(
GCNConv
(
hidden_channels
,
out_channels
))
def
forward
(
self
,
x
,
adj_t
,
batch_size
,
n_id
):
def
forward
(
self
,
x
,
adj_t
,
*
args
):
for
conv
,
history
in
zip
(
self
.
convs
[:
-
1
],
self
.
histories
):
for
conv
,
history
in
zip
(
self
.
convs
[:
-
1
],
self
.
histories
):
x
=
conv
(
x
,
adj_t
).
relu_
()
x
=
conv
(
x
,
adj_t
).
relu_
()
x
=
self
.
push_and_pull
(
history
,
x
,
batch_size
,
n_id
)
x
=
self
.
push_and_pull
(
history
,
x
,
*
args
)
return
self
.
convs
[
-
1
](
x
,
adj_t
)
return
self
.
convs
[
-
1
](
x
,
adj_t
)
perm
,
ptr
=
metis
(
data
.
adj_t
,
num_parts
=
40
,
log
=
True
)
data
=
permute
(
data
,
perm
,
log
=
True
)
loader
=
SubgraphLoader
(
data
,
ptr
,
batch_size
=
10
,
shuffle
=
True
)
for
batch
,
*
args
in
loader
:
out
=
model
(
batch
.
x
,
batch
.
adj_t
,
*
args
)
```
```
## Requirements
## Requirements
...
...
torch_geometric_autoscale/data.py
View file @
c427300e
...
@@ -110,24 +110,25 @@ def get_sbm(root: str, name: str) -> Tuple[Data, int, int]:
...
@@ -110,24 +110,25 @@ def get_sbm(root: str, name: str) -> Tuple[Data, int, int]:
def
get_data
(
root
:
str
,
name
:
str
)
->
Tuple
[
Data
,
int
,
int
]:
def
get_data
(
root
:
str
,
name
:
str
)
->
Tuple
[
Data
,
int
,
int
]:
if
name
.
lower
()
in
[
'cora'
,
'citeseer'
,
'pubmed'
]:
if
name
.
lower
()
in
[
'cora'
,
'citeseer'
,
'pubmed'
]:
return
get_planetoid
(
root
,
name
)
return
get_planetoid
(
root
,
name
)
if
name
.
lower
()
in
[
'coauthorcs'
,
'coauthorphysics'
]:
el
if
name
.
lower
()
in
[
'coauthorcs'
,
'coauthorphysics'
]:
return
get_coauthor
(
root
,
name
[
8
:])
return
get_coauthor
(
root
,
name
[
8
:])
if
name
.
lower
()
in
[
'amazoncomputers'
,
'amazonphoto'
]:
el
if
name
.
lower
()
in
[
'amazoncomputers'
,
'amazonphoto'
]:
return
get_amazon
(
root
,
name
[
6
:])
return
get_amazon
(
root
,
name
[
6
:])
if
name
.
lower
()
==
'wikics'
:
el
if
name
.
lower
()
==
'wikics'
:
return
get_wikics
(
root
)
return
get_wikics
(
root
)
if
name
.
lower
()
in
[
'cluster'
,
'pattern'
]:
el
if
name
.
lower
()
in
[
'cluster'
,
'pattern'
]:
return
get_sbm
(
root
,
name
)
return
get_sbm
(
root
,
name
)
if
name
.
lower
()
==
'reddit'
:
el
if
name
.
lower
()
==
'reddit'
:
return
get_reddit
(
root
)
return
get_reddit
(
root
)
if
name
.
lower
()
==
'ppi'
:
el
if
name
.
lower
()
==
'ppi'
:
return
get_ppi
(
root
)
return
get_ppi
(
root
)
if
name
.
lower
()
==
'flickr'
:
el
if
name
.
lower
()
==
'flickr'
:
return
get_flickr
(
root
)
return
get_flickr
(
root
)
if
name
.
lower
()
==
'yelp'
:
el
if
name
.
lower
()
==
'yelp'
:
return
get_yelp
(
root
)
return
get_yelp
(
root
)
if
name
.
lower
()
in
[
'ogbn-arxiv'
,
'arxiv'
]:
el
if
name
.
lower
()
in
[
'ogbn-arxiv'
,
'arxiv'
]:
return
get_arxiv
(
root
)
return
get_arxiv
(
root
)
if
name
.
lower
()
in
[
'ogbn-products'
,
'products'
]:
el
if
name
.
lower
()
in
[
'ogbn-products'
,
'products'
]:
return
get_products
(
root
)
return
get_products
(
root
)
raise
NotImplementedError
else
:
raise
NotImplementedError
torch_geometric_autoscale/metis.py
View file @
c427300e
...
@@ -46,9 +46,9 @@ def permute(data: Union[Data, SparseTensor], perm: Tensor,
...
@@ -46,9 +46,9 @@ def permute(data: Union[Data, SparseTensor], perm: Tensor,
for
key
,
item
in
data
:
for
key
,
item
in
data
:
if
isinstance
(
item
,
Tensor
)
and
item
.
size
(
0
)
==
data
.
num_nodes
:
if
isinstance
(
item
,
Tensor
)
and
item
.
size
(
0
)
==
data
.
num_nodes
:
data
[
key
]
=
item
[
perm
]
data
[
key
]
=
item
[
perm
]
if
isinstance
(
item
,
Tensor
)
and
item
.
size
(
0
)
==
data
.
num_edges
:
el
if
isinstance
(
item
,
Tensor
)
and
item
.
size
(
0
)
==
data
.
num_edges
:
raise
NotImplementedError
raise
NotImplementedError
if
isinstance
(
item
,
SparseTensor
):
el
if
isinstance
(
item
,
SparseTensor
):
data
[
key
]
=
permute
(
item
,
perm
,
log
=
False
)
data
[
key
]
=
permute
(
item
,
perm
,
log
=
False
)
else
:
else
:
data
=
data
.
permute
(
perm
)
data
=
data
.
permute
(
perm
)
...
...
torch_geometric_autoscale/models/base.py
View file @
c427300e
...
@@ -18,7 +18,7 @@ class ScalableGNN(torch.nn.Module):
...
@@ -18,7 +18,7 @@ class ScalableGNN(torch.nn.Module):
self
.
num_nodes
=
num_nodes
self
.
num_nodes
=
num_nodes
self
.
hidden_channels
=
hidden_channels
self
.
hidden_channels
=
hidden_channels
self
.
num_layers
=
num_layers
self
.
num_layers
=
num_layers
self
.
pool_size
=
num_layers
if
pool_size
is
None
else
pool_size
self
.
pool_size
=
num_layers
-
1
if
pool_size
is
None
else
pool_size
self
.
buffer_size
=
buffer_size
self
.
buffer_size
=
buffer_size
self
.
histories
=
torch
.
nn
.
ModuleList
([
self
.
histories
=
torch
.
nn
.
ModuleList
([
...
@@ -59,13 +59,17 @@ class ScalableGNN(torch.nn.Module):
...
@@ -59,13 +59,17 @@ class ScalableGNN(torch.nn.Module):
for
history
in
self
.
histories
:
for
history
in
self
.
histories
:
history
.
reset_parameters
()
history
.
reset_parameters
()
def
__call__
(
self
,
x
:
Optional
[
Tensor
]
=
None
,
def
__call__
(
adj_t
:
Optional
[
SparseTensor
]
=
None
,
self
,
batch_size
:
Optional
[
int
]
=
None
,
x
:
Optional
[
Tensor
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
adj_t
:
Optional
[
SparseTensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
batch_size
:
Optional
[
int
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
,
loader
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
**
kwargs
)
->
Tensor
:
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
,
loader
=
None
,
**
kwargs
,
)
->
Tensor
:
if
loader
is
not
None
:
if
loader
is
not
None
:
return
self
.
mini_inference
(
loader
)
return
self
.
mini_inference
(
loader
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment