Commit 12fab559 (unverified), authored Mar 06, 2024 by Ramon Zhou, committed via GitHub on Mar 06, 2024
[GraphBolt][PyG] Add more attributes in `to_pyg_data` (#7196)
Parent: a6505e86
Showing 3 changed files with 37 additions and 36 deletions (+37, -36).
examples/sampling/pyg/node_classification.py (+7, -7)
python/dgl/graphbolt/minibatch.py (+14, -0)
tests/python/pytorch/graphbolt/test_minibatch.py (+16, -29)
examples/sampling/pyg/node_classification.py
@@ -88,18 +88,18 @@ class GraphSAGE(torch.nn.Module):
                 x = F.dropout(x, p=0.5, training=self.training)
         return x
 
-    def inference(self, args, dataloader, x_all, device):
+    def inference(self, dataloader, x_all, device):
         """Conduct layer-wise inference to get all the node embeddings."""
         for i, layer in tqdm(enumerate(self.layers), "inference"):
             xs = []
             for minibatch in dataloader:
                 # Call `to_pyg_data` to convert GB Minibatch to PyG Data.
                 pyg_data = minibatch.to_pyg_data()
-                n_ids = minibatch.node_ids().to("cpu")
-                x = x_all[n_ids].to(device)
+                n_id = pyg_data.n_id.to("cpu")
+                x = x_all[n_id].to(device)
                 edge_index = pyg_data.edge_index
                 x = layer(x, edge_index)
-                x = x[: 4 * args.batch_size]
+                x = x[: pyg_data.batch_size]
                 if i != len(self.layers) - 1:
                     x = x.relu()
                 xs.append(x.cpu())

@@ -185,11 +185,11 @@ def evaluate(model, dataloader, num_classes):
 @torch.no_grad()
 def layerwise_infer(
-    model, args, infer_dataloader, test_set, feature, num_classes, device
+    model, infer_dataloader, test_set, feature, num_classes, device
 ):
     model.eval()
     features = feature.read("node", None, "feat")
-    pred = model.inference(args, infer_dataloader, features, device)
+    pred = model.inference(infer_dataloader, features, device)
     pred = pred[test_set._items[0]]
     label = test_set._items[1].to(pred.device)

@@ -271,7 +271,7 @@ def main():
         f"Valid Accuracy: {valid_accuracy:.4f}"
     )
     test_accuracy = layerwise_infer(
-        model, args, infer_dataloader, test_set, feature, num_classes, device
+        model, infer_dataloader, test_set, feature, num_classes, device
     )
     print(f"Test Accuracy: {test_accuracy:.4f}")
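The upshot of this file's changes: `args` no longer needs to be threaded through the inference path, because each converted PyG `Data` object now carries the minibatch's own `batch_size` (replacing the old `4 * args.batch_size` slice, which presumably mirrored the larger batch size used to build the inference dataloader) and the node IDs as `n_id`. As an illustrative sketch only, with `dataloader`, `layer`, `x_all`, and `device` assumed to exist as in the example above, one per-minibatch inference step becomes:

    for minibatch in dataloader:
        pyg_data = minibatch.to_pyg_data()
        # Gather input features on CPU by the original node IDs of the sampled block.
        x = x_all[pyg_data.n_id.to("cpu")].to(device)
        x = layer(x, pyg_data.edge_index)
        # As the example relies on, the seed nodes occupy the first `batch_size` rows.
        x = x[: pyg_data.batch_size]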
python/dgl/graphbolt/minibatch.py
@@ -526,10 +526,24 @@ class MiniBatch:
         ), "`to_pyg_data` only supports single feature homogeneous graph."
         node_features = next(iter(self.node_features.values()))
+        if self.seed_nodes is not None:
+            if isinstance(self.seed_nodes, Dict):
+                batch_size = len(next(iter(self.seed_nodes.values())))
+            else:
+                batch_size = len(self.seed_nodes)
+        elif self.node_pairs is not None:
+            if isinstance(self.node_pairs, Dict):
+                batch_size = len(next(iter(self.node_pairs.values()))[0])
+            else:
+                batch_size = len(self.node_pairs[0])
+        else:
+            batch_size = None
         pyg_data = Data(
             x=node_features,
             edge_index=edge_index,
             y=self.labels,
+            batch_size=batch_size,
+            n_id=self.node_ids(),
         )
         return pyg_data
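For reference, a minimal consumption sketch of the attributes `to_pyg_data` now populates. Only `MiniBatch.to_pyg_data()` and the `x`/`edge_index`/`y`/`batch_size`/`n_id` attributes come from the diff above; the model, loss, and dataloader names are illustrative assumptions:

    import torch.nn.functional as F

    for minibatch in dataloader:
        data = minibatch.to_pyg_data()
        h = model(data.x, data.edge_index)   # message passing over the sampled subgraph
        h = h[: data.batch_size]             # predictions for the seed nodes only
        loss = F.cross_entropy(h, data.y)    # `y` holds the seed-node labels in this flow
        # `data.n_id` maps each row of `data.x` back to a node ID in the full graph.

Note that `batch_size` is `None` when the minibatch has neither `seed_nodes` nor `node_pairs`, so callers relying on the slice above should handle that case.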
tests/python/pytorch/graphbolt/test_minibatch.py
@@ -869,40 +869,27 @@ def test_dgl_link_predication_hetero(mode):
 def test_to_pyg_data():
-    test_subgraph_a = gb.SampledSubgraphImpl(
-        sampled_csc=gb.CSCFormatBase(
-            indptr=torch.tensor([0, 1, 3, 5, 6]),
-            indices=torch.tensor([0, 1, 2, 2, 1, 2]),
-        ),
-        original_column_node_ids=torch.tensor([10, 11, 12, 13]),
-        original_row_node_ids=torch.tensor([19, 20, 21, 22, 25, 30]),
-        original_edge_ids=torch.tensor([10, 11, 12, 13]),
-    )
-    test_subgraph_b = gb.SampledSubgraphImpl(
-        sampled_csc=gb.CSCFormatBase(
-            indptr=torch.tensor([0, 1, 3]),
-            indices=torch.tensor([1, 2, 0]),
-        ),
-        original_row_node_ids=torch.tensor([10, 11, 12]),
-        original_edge_ids=torch.tensor([10, 15, 17]),
-        original_column_node_ids=torch.tensor([10, 11]),
-    )
+    test_minibatch = create_homo_minibatch()
+    test_minibatch.seed_nodes = torch.tensor([0, 1])
+    test_minibatch.labels = torch.tensor([7, 8])
     expected_edge_index = torch.tensor(
-        [[0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 0, 1, 2, 1, 2, 3]]
+        [[0, 0, 1, 1, 1, 2, 2, 2, 2], [0, 1, 0, 1, 2, 0, 1, 2, 3]]
     )
-    expected_node_features = torch.tensor([[1], [2], [3], [4]])
-    expected_labels = torch.tensor([0, 1])
-    test_minibatch = gb.MiniBatch(
-        sampled_subgraphs=[test_subgraph_a, test_subgraph_b],
-        node_features={"feat": expected_node_features},
-        labels=expected_labels,
-    )
+    expected_node_features = next(iter(test_minibatch.node_features.values()))
+    expected_labels = torch.tensor([7, 8])
+    expected_batch_size = 2
+    expected_n_id = torch.tensor([10, 11, 12, 13])
     pyg_data = test_minibatch.to_pyg_data()
     pyg_data.validate()
     assert torch.equal(pyg_data.edge_index, expected_edge_index)
     assert torch.equal(pyg_data.x, expected_node_features)
     assert torch.equal(pyg_data.y, expected_labels)
+    assert pyg_data.batch_size == expected_batch_size
+    assert torch.equal(pyg_data.n_id, expected_n_id)
+    subgraph = test_minibatch.sampled_subgraphs[0]
     # Test with sampled_csc as None.
     test_minibatch = gb.MiniBatch(
         sampled_subgraphs=None,

@@ -914,7 +901,7 @@ def test_to_pyg_data():
     # Test with node_features as None.
     test_minibatch = gb.MiniBatch(
-        sampled_subgraphs=[test_subgraph_a],
+        sampled_subgraphs=[subgraph],
         node_features=None,
         labels=expected_labels,
     )

@@ -923,7 +910,7 @@ def test_to_pyg_data():
     # Test with labels as None.
     test_minibatch = gb.MiniBatch(
-        sampled_subgraphs=[test_subgraph_a],
+        sampled_subgraphs=[subgraph],
         node_features={"feat": expected_node_features},
         labels=None,
     )

@@ -932,7 +919,7 @@ def test_to_pyg_data():
     # Test with multiple features.
     test_minibatch = gb.MiniBatch(
-        sampled_subgraphs=[test_subgraph_a],
+        sampled_subgraphs=[subgraph],
         node_features={
             "feat": expected_node_features,
             "extra_feat": torch.tensor([[3], [4]]),