Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
72b3e078
"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "7b8a6db7f450e70a1e0fb07e07b30dda6a7e6e1c"
Unverified
Commit
72b3e078
authored
Oct 19, 2023
by
peizhou001
Committed by
GitHub
Oct 19, 2023
Browse files
[Graphbolt] Remove redundant data in to_dgl (#6466)
parent
625f8a6b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
11 deletions
+41
-11
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+6
-2
tests/python/pytorch/graphbolt/impl/test_minibatch.py
tests/python/pytorch/graphbolt/impl/test_minibatch.py
+29
-3
tests/python/pytorch/graphbolt/test_integration.py
tests/python/pytorch/graphbolt/test_integration.py
+6
-6
No files found.
python/dgl/graphbolt/minibatch.py
View file @
72b3e078
...
@@ -314,12 +314,16 @@ class MiniBatch:
...
@@ -314,12 +314,16 @@ class MiniBatch:
"""
"""
minibatch
=
DGLMiniBatch
(
minibatch
=
DGLMiniBatch
(
blocks
=
self
.
_to_dgl_blocks
(),
blocks
=
self
.
_to_dgl_blocks
(),
input_nodes
=
self
.
input_nodes
,
output_nodes
=
self
.
seed_nodes
,
node_features
=
self
.
node_features
,
node_features
=
self
.
node_features
,
edge_features
=
self
.
edge_features
,
edge_features
=
self
.
edge_features
,
labels
=
self
.
labels
,
labels
=
self
.
labels
,
)
)
# Need input nodes to fetch feature.
if
self
.
node_features
is
None
:
minibatch
.
input_nodes
=
self
.
input_nodes
# Need output nodes to fetch label.
if
self
.
labels
is
None
:
minibatch
.
output_nodes
=
self
.
seed_nodes
assert
(
assert
(
minibatch
.
blocks
is
not
None
minibatch
.
blocks
is
not
None
),
"Sampled subgraphs for computation are missing."
),
"Sampled subgraphs for computation are missing."
...
...
tests/python/pytorch/graphbolt/impl/test_minibatch.py
View file @
72b3e078
...
@@ -50,6 +50,7 @@ def create_homo_minibatch():
...
@@ -50,6 +50,7 @@ def create_homo_minibatch():
sampled_subgraphs
=
subgraphs
,
sampled_subgraphs
=
subgraphs
,
node_features
=
node_features
,
node_features
=
node_features
,
edge_features
=
edge_features
,
edge_features
=
edge_features
,
input_nodes
=
torch
.
tensor
([
10
,
11
,
12
,
13
]),
)
)
...
@@ -103,6 +104,10 @@ def create_hetero_minibatch():
...
@@ -103,6 +104,10 @@ def create_hetero_minibatch():
sampled_subgraphs
=
subgraphs
,
sampled_subgraphs
=
subgraphs
,
node_features
=
node_features
,
node_features
=
node_features
,
edge_features
=
edge_features
,
edge_features
=
edge_features
,
input_nodes
=
{
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
]),
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
},
)
)
...
@@ -286,7 +291,7 @@ def test_dgl_minibatch_representation():
...
@@ -286,7 +291,7 @@ def test_dgl_minibatch_representation():
node_features={'x': tensor([7, 6, 2, 2])},
node_features={'x': tensor([7, 6, 2, 2])},
negative_node_pairs=(tensor([0, 1, 2]), tensor([6, 0, 0])),
negative_node_pairs=(tensor([0, 1, 2]), tensor([6, 0, 0])),
labels=tensor([0., 1., 2.]),
labels=tensor([0., 1., 2.]),
input_nodes=
tensor([8, 1, 6, 5, 9, 0, 2, 4])
,
input_nodes=
None
,
edge_features=[{'x': tensor([[8],
edge_features=[{'x': tensor([[8],
[1],
[1],
[6]])},
[6]])},
...
@@ -354,6 +359,25 @@ def check_dgl_blocks_homo(minibatch, blocks):
...
@@ -354,6 +359,25 @@ def check_dgl_blocks_homo(minibatch, blocks):
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
original_row_node_ids
[
0
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
original_row_node_ids
[
0
])
def
test_to_dgl_node_classification_without_feature
():
# Arrange
minibatch
=
create_homo_minibatch
()
minibatch
.
node_features
=
None
minibatch
.
labels
=
None
minibatch
.
seed_nodes
=
torch
.
tensor
([
10
,
15
])
# Act
dgl_minibatch
=
minibatch
.
to_dgl
()
# Assert
assert
len
(
dgl_minibatch
.
blocks
)
==
2
assert
dgl_minibatch
.
node_features
is
None
assert
minibatch
.
edge_features
is
dgl_minibatch
.
edge_features
assert
dgl_minibatch
.
labels
is
None
assert
minibatch
.
input_nodes
is
dgl_minibatch
.
input_nodes
assert
minibatch
.
seed_nodes
is
dgl_minibatch
.
output_nodes
check_dgl_blocks_homo
(
minibatch
,
dgl_minibatch
.
blocks
)
def
test_to_dgl_node_classification_homo
():
def
test_to_dgl_node_classification_homo
():
# Arrange
# Arrange
minibatch
=
create_homo_minibatch
()
minibatch
=
create_homo_minibatch
()
...
@@ -367,7 +391,8 @@ def test_to_dgl_node_classification_homo():
...
@@ -367,7 +391,8 @@ def test_to_dgl_node_classification_homo():
assert
minibatch
.
node_features
is
dgl_minibatch
.
node_features
assert
minibatch
.
node_features
is
dgl_minibatch
.
node_features
assert
minibatch
.
edge_features
is
dgl_minibatch
.
edge_features
assert
minibatch
.
edge_features
is
dgl_minibatch
.
edge_features
assert
minibatch
.
labels
is
dgl_minibatch
.
labels
assert
minibatch
.
labels
is
dgl_minibatch
.
labels
assert
minibatch
.
seed_nodes
is
dgl_minibatch
.
output_nodes
assert
dgl_minibatch
.
input_nodes
is
None
assert
dgl_minibatch
.
output_nodes
is
None
check_dgl_blocks_homo
(
minibatch
,
dgl_minibatch
.
blocks
)
check_dgl_blocks_homo
(
minibatch
,
dgl_minibatch
.
blocks
)
...
@@ -382,7 +407,8 @@ def test_to_dgl_node_classification_hetero():
...
@@ -382,7 +407,8 @@ def test_to_dgl_node_classification_hetero():
assert
minibatch
.
node_features
is
dgl_minibatch
.
node_features
assert
minibatch
.
node_features
is
dgl_minibatch
.
node_features
assert
minibatch
.
edge_features
is
dgl_minibatch
.
edge_features
assert
minibatch
.
edge_features
is
dgl_minibatch
.
edge_features
assert
minibatch
.
labels
is
dgl_minibatch
.
labels
assert
minibatch
.
labels
is
dgl_minibatch
.
labels
assert
minibatch
.
seed_nodes
is
dgl_minibatch
.
output_nodes
assert
dgl_minibatch
.
input_nodes
is
None
assert
dgl_minibatch
.
output_nodes
is
None
check_dgl_blocks_hetero
(
minibatch
,
dgl_minibatch
.
blocks
)
check_dgl_blocks_hetero
(
minibatch
,
dgl_minibatch
.
blocks
)
...
...
tests/python/pytorch/graphbolt/test_integration.py
View file @
72b3e078
...
@@ -71,7 +71,7 @@ def test_integration_link_prediction():
...
@@ -71,7 +71,7 @@ def test_integration_link_prediction():
[0.5503, 0.8223]])},
[0.5503, 0.8223]])},
negative_node_pairs=(tensor([0, 1, 1, 1]), tensor([0, 3, 4, 5])),
negative_node_pairs=(tensor([0, 1, 1, 1]), tensor([0, 3, 4, 5])),
labels=None,
labels=None,
input_nodes=
tensor([5, 3, 1, 2, 0, 4])
,
input_nodes=
None
,
edge_features=[{},
edge_features=[{},
{}],
{}],
blocks=[Block(num_src_nodes=6,
blocks=[Block(num_src_nodes=6,
...
@@ -92,7 +92,7 @@ def test_integration_link_prediction():
...
@@ -92,7 +92,7 @@ def test_integration_link_prediction():
[0.6172, 0.7865]])},
[0.6172, 0.7865]])},
negative_node_pairs=(tensor([0, 1, 1, 2]), tensor([1, 3, 4, 1])),
negative_node_pairs=(tensor([0, 1, 1, 2]), tensor([1, 3, 4, 1])),
labels=None,
labels=None,
input_nodes=
tensor([3, 4, 0, 5, 1])
,
input_nodes=
None
,
edge_features=[{},
edge_features=[{},
{}],
{}],
blocks=[Block(num_src_nodes=5,
blocks=[Block(num_src_nodes=5,
...
@@ -112,7 +112,7 @@ def test_integration_link_prediction():
...
@@ -112,7 +112,7 @@ def test_integration_link_prediction():
[0.9634, 0.2294]])},
[0.9634, 0.2294]])},
negative_node_pairs=(tensor([0, 1]), tensor([1, 2])),
negative_node_pairs=(tensor([0, 1]), tensor([1, 2])),
labels=None,
labels=None,
input_nodes=
tensor([5, 4, 3, 0])
,
input_nodes=
None
,
edge_features=[{},
edge_features=[{},
{}],
{}],
blocks=[Block(num_src_nodes=4,
blocks=[Block(num_src_nodes=4,
...
@@ -193,7 +193,7 @@ def test_integration_node_classification():
...
@@ -193,7 +193,7 @@ def test_integration_node_classification():
[0.9634, 0.2294]])},
[0.9634, 0.2294]])},
negative_node_pairs=None,
negative_node_pairs=None,
labels=None,
labels=None,
input_nodes=
tensor([5, 3, 1, 2, 4, 0])
,
input_nodes=
None
,
edge_features=[{},
edge_features=[{},
{}],
{}],
blocks=[Block(num_src_nodes=6,
blocks=[Block(num_src_nodes=6,
...
@@ -212,7 +212,7 @@ def test_integration_node_classification():
...
@@ -212,7 +212,7 @@ def test_integration_node_classification():
[0.9634, 0.2294]])},
[0.9634, 0.2294]])},
negative_node_pairs=None,
negative_node_pairs=None,
labels=None,
labels=None,
input_nodes=
tensor([3, 4, 0])
,
input_nodes=
None
,
edge_features=[{},
edge_features=[{},
{}],
{}],
blocks=[Block(num_src_nodes=3,
blocks=[Block(num_src_nodes=3,
...
@@ -231,7 +231,7 @@ def test_integration_node_classification():
...
@@ -231,7 +231,7 @@ def test_integration_node_classification():
[0.9634, 0.2294]])},
[0.9634, 0.2294]])},
negative_node_pairs=None,
negative_node_pairs=None,
labels=None,
labels=None,
input_nodes=
tensor([5, 4, 0])
,
input_nodes=
None
,
edge_features=[{},
edge_features=[{},
{}],
{}],
blocks=[Block(num_src_nodes=3,
blocks=[Block(num_src_nodes=3,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment