OpenDAS / pyg_autoscale / Commits

Commit b8ee6962
authored Jun 08, 2021 by rusty1s
clean up loader
parent dc5f7414
Showing 2 changed files with 85 additions and 123 deletions:
torch_geometric_autoscale/loader.py (+84, -121)
torch_geometric_autoscale/models/base.py (+1, -2)
torch_geometric_autoscale/loader.py
-from typing import Optional, Union, Tuple, NamedTuple, List
+from typing import NamedTuple, List, Tuple
 import time
...
@@ -12,11 +12,11 @@ relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop

 class SubData(NamedTuple):
-    data: Union[Data, SparseTensor]
+    data: Data
     batch_size: int
     n_id: Tensor
-    offset: Optional[Tensor]
-    count: Optional[Tensor]
+    offset: Tensor
+    count: Tensor

     def to(self, *args, **kwargs):
         return SubData(self.data.to(*args, **kwargs), self.batch_size,
...
@@ -24,138 +24,101 @@ class SubData(NamedTuple):

 class SubgraphLoader(DataLoader):
-    r"""A simple subgraph loader that, given a randomly sampled or
-    pre-partioned batch of nodes, returns the subgraph of this batch
-    (including its 1-hop neighbors)."""
-    def __init__(self, data: Union[Data, SparseTensor],
-                 ptr: Optional[Tensor] = None, batch_size: int = 1,
-                 bipartite: bool = True, log: bool = True, **kwargs):
-
-        self.__data__ = None if isinstance(data, SparseTensor) else data
-        self.__adj_t__ = data if isinstance(data, SparseTensor) else data.adj_t
-        self.__N__ = self.__adj_t__.size(1)
-        self.__E__ = self.__adj_t__.nnz()
-        self.__ptr__ = ptr
-        self.__bipartite__ = bipartite
-
-        if ptr is not None:
-            n_id = torch.arange(self.__N__)
+    r"""A simple subgraph loader that, given a pre-partioned :obj:`data` object,
+    generates subgraphs (including its 1-hop neighbors) from mini-batches in
+    :obj:`ptr`."""
+    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
+                 bipartite: bool = True, log: bool = True, **kwargs):
+
+        self.data = data
+        self.ptr = ptr
+        self.bipartite = bipartite
+        self.log = log
+
+        n_id = torch.arange(data.num_nodes)
         batches = n_id.split((ptr[1:] - ptr[:-1]).tolist())
         batches = [(i, batches[i]) for i in range(len(batches))]

         if batch_size > 1:
-            super(SubgraphLoader, self).__init__(
-                batches, collate_fn=self.sample_partitions,
-                batch_size=batch_size, **kwargs)
+            super(SubgraphLoader, self).__init__(
+                batches, batch_size=batch_size,
+                collate_fn=self.compute_subgraph, **kwargs)

         else:
             # If `batch_size=1`, we pre-process the subgraph generation:
             if log:
                 t = time.perf_counter()
                 print('Pre-processing subgraphs...', end=' ', flush=True)

-            data_list = [data for data in DataLoader(
-                batches, collate_fn=self.sample_partitions,
-                batch_size=batch_size, **kwargs)]
+            data_list = list(
+                DataLoader(batches, collate_fn=self.compute_subgraph,
                           batch_size=batch_size, **kwargs))

             if log:
                 print(f'Done! [{time.perf_counter() - t:.2f}s]')

-            super(SubgraphLoader, self).__init__(
-                data_list, batch_size=1,
-                collate_fn=lambda x: x[0], **kwargs)
-
-        else:
-            super(SubgraphLoader, self).__init__(
-                range(self.__N__), collate_fn=self.sample_nodes,
-                batch_size=batch_size, **kwargs)
-
-    def sample_partitions(self, batches: List[Tuple[int, Tensor]]) -> SubData:
-        ptr_ids, n_ids = zip(*batches)
-        n_id = torch.cat(n_ids, dim=0)
-        batch_size = n_id.numel()
-
-        ptr_id = torch.tensor(ptr_ids)
-        offset = self.__ptr__[ptr_id]
-        count = self.__ptr__[ptr_id.add_(1)].sub_(offset)
-
-        rowptr, col, value = self.__adj_t__.csr()
-        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
-                                              self.__bipartite__)
-
-        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
-                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
-                             is_sorted=True)
-
-        if self.__data__ is None:
-            return SubData(adj_t, batch_size, n_id, offset, count)
-
-        data = self.__data__.__class__(adj_t=adj_t)
-        for key, item in self.__data__:
-            if isinstance(item, Tensor) and item.size(0) == self.__N__:
-                data[key] = item.index_select(0, n_id)
-            elif isinstance(item, SparseTensor):
-                pass
-            else:
-                data[key] = item
-
-        return SubData(data, batch_size, n_id, offset, count)
-
-    def sample_nodes(self, n_ids: List[int]) -> SubData:
-        n_id = torch.tensor(n_ids)
-        batch_size = n_id.numel()
-
-        rowptr, col, value = self.__adj_t__.csr()
-        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
-                                              self.__bipartite__)
-
-        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
-                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
-                             is_sorted=True)
-
-        if self.__data__ is None:
-            return SubData(adj_t, batch_size, n_id, None, None)
-
-        data = self.__data__.__class__(adj_t=adj_t)
-        for key, item in self.__data__:
-            if isinstance(item, Tensor) and item.size(0) == self.__N__:
-                data[key] = item.index_select(0, n_id)
-            elif isinstance(item, SparseTensor):
-                pass
-            else:
-                data[key] = item
-
-        return SubData(data, batch_size, n_id, None, None)
+            super(SubgraphLoader, self).__init__(
+                data_list, batch_size=batch_size,
+                collate_fn=lambda x: x[0], **kwargs)
+
+    def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
+        batch_ids, n_ids = zip(*batches)
+        n_id = torch.cat(n_ids, dim=0)
+        batch_id = torch.tensor(batch_ids)
+
+        # We collect the in-mini-batch size (`batch_size`), the offset of each
+        # partition in the mini-batch (`offset`), and the number of nodes in
+        # each partition (`count`)
+        batch_size = n_id.numel()
+        offset = self.ptr[batch_id]
+        count = self.ptr[batch_id.add_(1)].sub_(offset)
+
+        rowptr, col, value = self.data.adj_t.csr()
+        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
+                                              self.bipartite)
+
+        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
+                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
+                             is_sorted=True)
+
+        data = self.data.__class__(adj_t=adj_t)
+        for k, v in self.data:
+            if isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
+                data[k] = v.index_select(0, n_id)
+
+        return SubData(data, batch_size, n_id, offset, count)

     def __repr__(self):
         return f'{self.__class__.__name__}()'


 class EvalSubgraphLoader(SubgraphLoader):
-    def __init__(self, data: Union[Data, SparseTensor],
-                 ptr: Optional[Tensor] = None, batch_size: int = 1,
-                 bipartite: bool = True, log: bool = True, **kwargs):
-        num_nodes = ptr[-1]
+    r"""A simple subgraph loader that, given a pre-partioned :obj:`data` object,
+    generates subgraphs (including its 1-hop neighbors) from mini-batches in
+    :obj:`ptr`.
+    In contrast to :class:`SubgraphLoader`, this loader does not generate
+    subgraphs from randomly sampled mini-batches, and should therefore only be
+    used for evaluation."""
+    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
+                 bipartite: bool = True, log: bool = True, **kwargs):
+
         ptr = ptr[::batch_size]
-        if int(ptr[-1]) != int(num_nodes):
-            ptr = torch.cat([ptr, num_nodes.unsqueeze(0)], dim=0)
-
-        super(EvalSubgraphLoader, self).__init__(
-            data, ptr, 1, bipartite, log, num_workers=0, shuffle=False,
-            **kwargs)
+        if int(ptr[-1]) != data.num_nodes:
+            ptr = torch.cat([ptr, torch.tensor(data.num_nodes)], dim=0)
+
+        super(EvalSubgraphLoader, self).__init__(
+            data=data, ptr=ptr, batch_size=1, bipartite=bipartite, log=log,
+            shuffle=False, num_workers=0, **kwargs)
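For context, the reworked loader API can be exercised roughly as follows. This is a minimal sketch, not code from the repository: the toy graph, feature sizes, partition pointer values, and the top-level import path are assumptions, and actually iterating the loaders requires the package's compiled `relabel_one_hop` operator to be installed.

# Minimal usage sketch (illustrative values; assumes the package re-exports
# SubgraphLoader/EvalSubgraphLoader and that its compiled ops are available).
import torch
from torch_geometric.data import Data
from torch_sparse import SparseTensor
from torch_geometric_autoscale import SubgraphLoader, EvalSubgraphLoader

# Toy 6-node graph whose nodes are already ordered by partition
# (in practice the data is permuted according to a METIS partition).
row = torch.tensor([0, 1, 1, 2, 3, 4, 5, 5])
col = torch.tensor([1, 0, 2, 3, 4, 5, 0, 4])
data = Data(x=torch.randn(6, 16), y=torch.randint(0, 2, (6,)))
data.adj_t = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))

# Partition boundaries: three partitions of two nodes each.
ptr = torch.tensor([0, 2, 4, 6])

train_loader = SubgraphLoader(data, ptr, batch_size=2, shuffle=True)
for sub in train_loader:  # each item is a SubData tuple
    # sub.data holds the 1-hop subgraph of the sampled partitions;
    # sub.offset / sub.count locate each partition inside the histories.
    print(sub.batch_size, sub.n_id.numel(), sub.offset, sub.count)

# Evaluation loader: deterministic, non-shuffled mini-batches over `ptr`.
eval_loader = EvalSubgraphLoader(data, ptr)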
torch_geometric_autoscale/models/base.py
...
@@ -88,8 +88,7 @@ class ScalableGNN(torch.nn.Module):
             for hist in self.histories:
                 self.pool.async_pull(hist.emb, None, None, n_id[batch_size:])

-        out = self.forward(x=x, adj_t=adj_t, batch_size=batch_size,
-                           n_id=n_id, offset=offset, count=count, **kwargs)
+        out = self.forward(x, adj_t, batch_size, n_id, offset, count, **kwargs)

         if self._async:
             for hist in self.histories:
...
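The `base.py` change only switches the internal `forward` call from keyword to positional arguments, so subclass `forward` implementations must keep their parameters in exactly this order. Below is a sketch of what that implies for a subclass signature; the class, layer, and propagation step are purely illustrative and not repository code.

# Sketch only: the argument order a subclass's forward() must now follow.
from typing import Optional

import torch
from torch import Tensor
from torch_sparse import SparseTensor


class ToyModel(torch.nn.Module):  # stands in for a ScalableGNN subclass
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x: Tensor, adj_t: SparseTensor, batch_size: int,
                n_id: Tensor, offset: Optional[Tensor],
                count: Optional[Tensor], **kwargs) -> Tensor:
        # With `self.forward(x, adj_t, batch_size, n_id, offset, count,
        # **kwargs)`, arguments bind by position, so this parameter order
        # (rather than the keyword names) is what the call site relies on.
        h = adj_t.matmul(self.lin(x))  # toy propagation step
        return h[:batch_size]          # outputs only for in-batch nodes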