Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
df3e7f1b
Unverified
Commit
df3e7f1b
authored
Dec 05, 2023
by
xiangyuzhi
Committed by
GitHub
Dec 05, 2023
Browse files
[Sparse] Polish sparse example for graphbolt (#6691)
parent
3cdc37cc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
27 deletions
+35
-27
examples/sampling/graphbolt/sparse/graphsage.py
examples/sampling/graphbolt/sparse/graphsage.py
+35
-27
No files found.
examples/sampling/graphbolt/sparse/graphsage.py
View file @
df3e7f1b
...
@@ -34,9 +34,8 @@ import torch
...
@@ -34,9 +34,8 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torchmetrics.functional
as
MF
import
torchmetrics.functional
as
MF
from
dgl.data
import
AsNodePredDataset
from
dgl.graphbolt.subgraph_sampler
import
SubgraphSampler
from
ogb.nodeproppred
import
DglNodePropPredDataset
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
IterableWrapper
class
SAGEConv
(
nn
.
Module
):
class
SAGEConv
(
nn
.
Module
):
...
@@ -102,9 +101,21 @@ class SAGE(nn.Module):
...
@@ -102,9 +101,21 @@ class SAGE(nn.Module):
return
hidden_x
return
hidden_x
def
multilayer_sample
(
A
,
fanouts
,
minibatch
):
@
functional_datapipe
(
"sample_sparse_neighbor"
)
class
SparseNeighborSampler
(
SubgraphSampler
):
def
__init__
(
self
,
datapipe
,
matrix
,
fanouts
):
super
().
__init__
(
datapipe
)
self
.
matrix
=
matrix
# Convert fanouts to a list of tensors.
self
.
fanouts
=
[]
for
fanout
in
fanouts
:
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
fanout
=
torch
.
LongTensor
([
int
(
fanout
)])
self
.
fanouts
.
insert
(
0
,
fanout
)
def
_sample_subgraphs
(
self
,
seeds
):
sampled_matrices
=
[]
sampled_matrices
=
[]
src
=
minibatch
.
seed_node
s
src
=
seed
s
#####################################################################
#####################################################################
# (HIGHLIGHT) Using the sparse sample operator to preform random
# (HIGHLIGHT) Using the sparse sample operator to preform random
...
@@ -112,18 +123,15 @@ def multilayer_sample(A, fanouts, minibatch):
...
@@ -112,18 +123,15 @@ def multilayer_sample(A, fanouts, minibatch):
# compact operator is then employed to compact and relabel the sampled
# compact operator is then employed to compact and relabel the sampled
# matrix, resulting in the sampled matrix and the relabel index.
# matrix, resulting in the sampled matrix and the relabel index.
#####################################################################
#####################################################################
for
fanout
in
self
.
fanouts
:
for
fanout
in
fanouts
:
# Sample neighbors.
# Sample neighbors.
sampled_matrix
=
A
.
sample
(
1
,
fanout
,
ids
=
src
).
coalesce
()
sampled_matrix
=
self
.
matrix
.
sample
(
1
,
fanout
,
ids
=
src
).
coalesce
()
# Compact the sampled matrix.
# Compact the sampled matrix.
compacted_mat
,
row_ids
=
sampled_matrix
.
compact
(
0
)
compacted_mat
,
row_ids
=
sampled_matrix
.
compact
(
0
)
sampled_matrices
.
insert
(
0
,
compacted_mat
)
sampled_matrices
.
insert
(
0
,
compacted_mat
)
src
=
row_ids
src
=
row_ids
minibatch
.
input_nodes
=
src
return
src
,
sampled_matrices
minibatch
.
sampled_subgraphs
=
sampled_matrices
return
minibatch
############################################################################
############################################################################
...
@@ -132,11 +140,11 @@ def multilayer_sample(A, fanouts, minibatch):
...
@@ -132,11 +140,11 @@ def multilayer_sample(A, fanouts, minibatch):
def
create_dataloader
(
A
,
fanouts
,
ids
,
features
,
device
):
def
create_dataloader
(
A
,
fanouts
,
ids
,
features
,
device
):
datapipe
=
gb
.
ItemSampler
(
ids
,
batch_size
=
1024
)
datapipe
=
gb
.
ItemSampler
(
ids
,
batch_size
=
1024
)
# Customize graphbolt sampler by sparse.
# Customize graphbolt sampler by sparse.
datapipe
=
datapipe
.
map
(
partial
(
multilayer_sample
,
A
,
fanouts
)
)
datapipe
=
datapipe
.
sample_sparse_neighbor
(
A
,
fanouts
)
# Use grapbolt to fetch features.
# Use grapbolt to fetch features.
datapipe
=
datapipe
.
fetch_feature
(
features
,
node_feature_keys
=
[
"feat"
])
datapipe
=
datapipe
.
fetch_feature
(
features
,
node_feature_keys
=
[
"feat"
])
datapipe
=
datapipe
.
copy_to
(
device
)
datapipe
=
datapipe
.
copy_to
(
device
)
dataloader
=
gb
.
DataLoader
(
datapipe
,
num_workers
=
4
)
dataloader
=
gb
.
DataLoader
(
datapipe
)
return
dataloader
return
dataloader
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment