OpenDAS / pyg_autoscale · Commit 727f3279

clean up history

Authored Jun 08, 2021 by rusty1s
Parent: b8ee6962

Showing 2 changed files with 24 additions and 40 deletions:

  torch_geometric_autoscale/history.py  (+18, -18)
  torch_geometric_autoscale/loader.py   (+6, -22)
torch_geometric_autoscale/history.py
@@ -5,10 +5,9 @@ from torch import Tensor
 
 
 class History(torch.nn.Module):
-    r"""A node embedding storage module with asynchronous I/O support between
-    devices."""
+    r"""A historical embedding storage module."""
 
     def __init__(self, num_embeddings: int, embedding_dim: int, device=None):
-        super(History, self).__init__()
+        super().__init__()
 
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -25,37 +24,38 @@ class History(torch.nn.Module):
         self.emb.fill_(0)
 
     def _apply(self, fn):
         # Set the `_device` of the module without transfering `self.emb`.
         self._device = fn(torch.zeros(1)).device
         return self
 
     @torch.no_grad()
-    def pull(self, index: Optional[Tensor] = None) -> Tensor:
+    def pull(self, n_id: Optional[Tensor] = None) -> Tensor:
         out = self.emb
-        if index is not None:
-            assert index.device == self.emb.device
-            out = out.index_select(0, index)
+        if n_id is not None:
+            assert n_id.device == self.emb.device
+            out = out.index_select(0, n_id)
         return out.to(device=self._device)
 
     @torch.no_grad()
-    def push(self, x, index: Optional[Tensor] = None,
+    def push(self, x, n_id: Optional[Tensor] = None,
              offset: Optional[Tensor] = None, count: Optional[Tensor] = None):
 
-        if index is None and x.size(0) != self.num_embeddings:
+        if n_id is None and x.size(0) != self.num_embeddings:
             raise ValueError
 
-        elif index is None and x.size(0) == self.num_embeddings:
+        elif n_id is None and x.size(0) == self.num_embeddings:
             self.emb.copy_(x)
 
-        elif index is not None and (offset is None or count is None):
-            assert index.device == self.emb.device
-            self.emb[index] = x.to(self.emb.device)
+        elif offset is None or count is None:
+            assert n_id.device == self.emb.device
+            self.emb[n_id] = x.to(self.emb.device)
 
-        else:
-            x_o = 0
+        else:
+            # Push in chunks:
+            src_o = 0
             x = x.to(self.emb.device)
-            for o, c, in zip(offset.tolist(), count.tolist()):
-                self.emb[o:o + c] = x[x_o:x_o + c]
-                x_o += c
+            for dst_o, c, in zip(offset.tolist(), count.tolist()):
+                self.emb[dst_o:dst_o + c] = x[src_o:src_o + c]
+                src_o += c
 
     def forward(self, *args, **kwargs):
         """"""
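For orientation, here is a minimal usage sketch of the `History` API as it reads after this commit. It assumes `History` is exported at the package level; the sizes and indices are illustrative, and the chunked-push call follows the branch logic visible in the hunk above rather than any documented contract.

# Minimal usage sketch (assumptions: `History` is importable from the
# package; sizes and indices below are illustrative only).
import torch
from torch_geometric_autoscale import History

history = History(num_embeddings=100, embedding_dim=16)

# Push freshly computed embeddings for a mini-batch of nodes into storage.
n_id = torch.tensor([0, 5, 42])   # global node indices of the mini-batch
x = torch.randn(3, 16)            # their new embeddings
history.push(x, n_id=n_id)        # scatter branch: emb[n_id] = x

# Chunked push: with `offset` and `count` given, rows are copied in
# contiguous slices, emb[dst_o:dst_o + c] = x[src_o:src_o + c]; per the
# branch order above, `n_id` still guards the size check but is not used
# for indexing in this branch.
offset = torch.tensor([0, 50])    # destination offsets in the storage
count = torch.tensor([2, 1])      # chunk lengths (sum equals x.size(0))
history.push(x, n_id=n_id, offset=offset, count=count)

# Pull embeddings for a subset of nodes back onto the compute device.
out = history.pull(n_id=n_id)     # shape [3, 16]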
torch_geometric_autoscale/loader.py
@@ -40,12 +40,8 @@ class SubgraphLoader(DataLoader):
         batches = [(i, batches[i]) for i in range(len(batches))]
 
         if batch_size > 1:
-            super(SubgraphLoader, self).__init__(
-                batches,
-                batch_size=batch_size,
-                collate_fn=self.compute_subgraph,
-                **kwargs,
-            )
+            super().__init__(batches, batch_size=batch_size,
+                             collate_fn=self.compute_subgraph, **kwargs)
 
         else:  # If `batch_size=1`, we pre-process the subgraph generation:
             if log:
@@ -59,12 +55,8 @@ class SubgraphLoader(DataLoader):
             if log:
                 print(f'Done! [{time.perf_counter() - t:.2f}s]')
 
-            super(SubgraphLoader, self).__init__(
-                data_list,
-                batch_size=batch_size,
-                collate_fn=lambda x: x[0],
-                **kwargs,
-            )
+            super().__init__(data_list, batch_size=batch_size,
+                             collate_fn=lambda x: x[0], **kwargs)
 
     def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
         batch_ids, n_ids = zip(*batches)
@@ -112,13 +104,5 @@ class EvalSubgraphLoader(SubgraphLoader):
         if int(ptr[-1]) != data.num_nodes:
             ptr = torch.cat([ptr, torch.tensor(data.num_nodes)], dim=0)
 
-        super(EvalSubgraphLoader, self).__init__(
-            data=data,
-            ptr=ptr,
-            batch_size=1,
-            bipartite=bipartite,
-            log=log,
-            shuffle=False,
-            num_workers=0,
-            **kwargs,
-        )
+        super().__init__(data=data, ptr=ptr, batch_size=1, bipartite=bipartite,
+                         log=log, shuffle=False, num_workers=0, **kwargs)
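The two construction paths above differ only in when the subgraph is materialized: for batch_size > 1 the `(i, n_id)` tuples are collated into a subgraph on-the-fly, while for batch_size == 1 the subgraphs are precomputed and the collate function merely unwraps them. The sketch below mirrors that pattern on a plain torch DataLoader; `make_subgraph` is a hypothetical stand-in for `compute_subgraph`, whose body is not part of this diff.

# Sketch of the two collate paths (assumption: `make_subgraph` stands in
# for the real `compute_subgraph`; here it only merges node ids).
from typing import List, Tuple

import torch
from torch import Tensor
from torch.utils.data import DataLoader

n_ids = [torch.arange(0, 4), torch.arange(4, 8), torch.arange(8, 12)]
batches = [(i, n_ids[i]) for i in range(len(n_ids))]  # as in the diff above

def make_subgraph(items: List[Tuple[int, Tensor]]):
    # Unzip batch ids and node ids, then merge the node ids of all
    # sampled mini-batches into a single subgraph.
    batch_ids, ids = zip(*items)
    return batch_ids, torch.cat(ids, dim=0)

# batch_size > 1: subgraphs are built on-the-fly in the collate function.
loader = DataLoader(batches, batch_size=2, collate_fn=make_subgraph)

# batch_size == 1: subgraphs are precomputed once; collate just unwraps
# the singleton list, exactly like `lambda x: x[0]` above.
data_list = [make_subgraph([b]) for b in batches]
eval_loader = DataLoader(data_list, batch_size=1, collate_fn=lambda x: x[0])

for batch_ids, merged_n_id in loader:
    print(batch_ids, merged_n_id.tolist())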