Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
pyg_autoscale
Commits
a73bb262
Commit
a73bb262
authored
Jun 08, 2021
by
rusty1s
Browse files
doc
parent
6325fa72
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
4 deletions
+16
-4
torch_geometric_autoscale/models/base.py
torch_geometric_autoscale/models/base.py
+16
-4
No files found.
torch_geometric_autoscale/models/base.py
View file @
a73bb262
...
@@ -29,6 +29,7 @@ class ScalableGNN(torch.nn.Module):
...
@@ -29,6 +29,7 @@ class ScalableGNN(torch.nn.Module):
self
.
pool
:
Optional
[
AsyncIOPool
]
=
None
self
.
pool
:
Optional
[
AsyncIOPool
]
=
None
self
.
_async
=
False
self
.
_async
=
False
self
.
__out
:
Optional
[
Tensor
]
=
None
@
property
@
property
def
emb_device
(
self
):
def
emb_device
(
self
):
...
@@ -135,20 +136,28 @@ class ScalableGNN(torch.nn.Module):
...
@@ -135,20 +136,28 @@ class ScalableGNN(torch.nn.Module):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
mini_inference
(
self
,
loader
:
SubgraphLoader
)
->
Tensor
:
def
mini_inference
(
self
,
loader
:
SubgraphLoader
)
->
Tensor
:
loader
=
[
data
+
({},
)
for
data
in
loader
]
# We iterate over the loader in a layer-wise fashsion.
# In order to re-use some intermediate representations, we maintain a
# `state` dictionary for each individual mini-batch.
for
batch
,
batch_size
,
n_id
,
offset
,
count
,
state
in
loader
:
loader
=
[
sub_data
+
({},
)
for
sub_data
in
loader
]
x
=
batch
.
x
.
to
(
self
.
device
)
adj_t
=
batch
.
adj_t
.
to
(
self
.
device
)
# We push the outputs of the first layer to the history:
for
data
,
batch_size
,
n_id
,
offset
,
count
,
state
in
loader
:
x
=
data
.
x
.
to
(
self
.
device
)
adj_t
=
data
.
adj_t
.
to
(
self
.
device
)
out
=
self
.
forward_layer
(
0
,
x
,
adj_t
,
state
)[:
batch_size
]
out
=
self
.
forward_layer
(
0
,
x
,
adj_t
,
state
)[:
batch_size
]
self
.
pool
.
async_push
(
out
,
offset
,
count
,
self
.
histories
[
0
].
emb
)
self
.
pool
.
async_push
(
out
,
offset
,
count
,
self
.
histories
[
0
].
emb
)
self
.
pool
.
synchronize_push
()
self
.
pool
.
synchronize_push
()
for
i
in
range
(
1
,
len
(
self
.
histories
)):
for
i
in
range
(
1
,
len
(
self
.
histories
)):
# Pull the complete layer-wise history:
for
_
,
batch_size
,
n_id
,
offset
,
count
,
_
in
loader
:
for
_
,
batch_size
,
n_id
,
offset
,
count
,
_
in
loader
:
self
.
pool
.
async_pull
(
self
.
histories
[
i
-
1
].
emb
,
offset
,
count
,
self
.
pool
.
async_pull
(
self
.
histories
[
i
-
1
].
emb
,
offset
,
count
,
n_id
[
batch_size
:])
n_id
[
batch_size
:])
# Compute new output embeddings one-by-one and start pushing them
# to the history.
for
batch
,
batch_size
,
n_id
,
offset
,
count
,
state
in
loader
:
for
batch
,
batch_size
,
n_id
,
offset
,
count
,
state
in
loader
:
adj_t
=
batch
.
adj_t
.
to
(
self
.
device
)
adj_t
=
batch
.
adj_t
.
to
(
self
.
device
)
x
=
self
.
pool
.
synchronize_pull
()[:
n_id
.
numel
()]
x
=
self
.
pool
.
synchronize_pull
()[:
n_id
.
numel
()]
...
@@ -157,10 +166,13 @@ class ScalableGNN(torch.nn.Module):
...
@@ -157,10 +166,13 @@ class ScalableGNN(torch.nn.Module):
self
.
pool
.
free_pull
()
self
.
pool
.
free_pull
()
self
.
pool
.
synchronize_push
()
self
.
pool
.
synchronize_push
()
# We pull the histories from the last layer:
for
_
,
batch_size
,
n_id
,
offset
,
count
,
_
in
loader
:
for
_
,
batch_size
,
n_id
,
offset
,
count
,
_
in
loader
:
self
.
pool
.
async_pull
(
self
.
histories
[
-
1
].
emb
,
offset
,
count
,
self
.
pool
.
async_pull
(
self
.
histories
[
-
1
].
emb
,
offset
,
count
,
n_id
[
batch_size
:])
n_id
[
batch_size
:])
# And compute final output embeddings, which we write into a private
# output embedding matrix:
for
batch
,
batch_size
,
n_id
,
offset
,
count
,
state
in
loader
:
for
batch
,
batch_size
,
n_id
,
offset
,
count
,
state
in
loader
:
adj_t
=
batch
.
adj_t
.
to
(
self
.
device
)
adj_t
=
batch
.
adj_t
.
to
(
self
.
device
)
x
=
self
.
pool
.
synchronize_pull
()[:
n_id
.
numel
()]
x
=
self
.
pool
.
synchronize_pull
()[:
n_id
.
numel
()]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment