Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
67cd09d7
Unverified
Commit
67cd09d7
authored
Dec 20, 2023
by
Ramon Zhou
Committed by
GitHub
Dec 20, 2023
Browse files
[GraphBolt] Support heterogeneous graph in `node_pairs_with_labels` (#6787)
parent
541f2ba4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
14 deletions
+81
-14
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+39
-14
tests/python/pytorch/graphbolt/impl/test_minibatch.py
tests/python/pytorch/graphbolt/impl/test_minibatch.py
+42
-0
No files found.
python/dgl/graphbolt/minibatch.py
View file @
67cd09d7
...
...
@@ -381,20 +381,42 @@ class MiniBatch:
@
property
def
node_pairs_with_labels
(
self
):
"""Get a node pair tensor and a label tensor from MiniBatch. They are
used for evaluating or computing loss. It will return
`(node_pairs, labels)` as result.
used for evaluating or computing loss. For homogeneous graph, it will
return `(node_pairs, labels)` as result; for heterogeneous graph, the
`node_pairs` and `labels` will both be a dict with etype as the key.
- If it's a link prediction task, `node_pairs` will contain both
negative and positive node pairs and `labels` will consist of 0 and 1,
indicating whether the corresponding node pair is negative or positive.
- If it's an edge classification task, this function will directly
return `compacted_node_pairs` and corresponding `labels`.
return `compacted_node_pairs` for each etype and the corresponding
`labels`.
- Otherwise it will return None.
"""
if
self
.
labels
is
None
:
# Link prediction.
positive_node_pairs
=
self
.
positive_node_pairs
negative_node_pairs
=
self
.
negative_node_pairs
if
positive_node_pairs
is
None
or
negative_node_pairs
is
None
:
return
None
if
isinstance
(
positive_node_pairs
,
Dict
):
# Heterogeneous graph.
node_pairs_by_etype
=
{}
labels_by_etype
=
{}
for
etype
in
positive_node_pairs
:
pos_src
,
pos_dst
=
positive_node_pairs
[
etype
]
neg_src
,
neg_dst
=
negative_node_pairs
[
etype
]
node_pairs_by_etype
[
etype
]
=
(
torch
.
cat
((
pos_src
,
neg_src
),
dim
=
0
),
torch
.
cat
((
pos_dst
,
neg_dst
),
dim
=
0
),
)
pos_label
=
torch
.
ones_like
(
pos_src
)
neg_label
=
torch
.
zeros_like
(
neg_src
)
labels_by_etype
[
etype
]
=
torch
.
cat
(
[
pos_label
,
neg_label
],
dim
=
0
)
return
(
node_pairs_by_etype
,
labels_by_etype
)
else
:
# Homogeneous graph.
pos_src
,
pos_dst
=
positive_node_pairs
neg_src
,
neg_dst
=
negative_node_pairs
node_pairs
=
(
...
...
@@ -405,8 +427,11 @@ class MiniBatch:
neg_label
=
torch
.
zeros_like
(
neg_src
)
labels
=
torch
.
cat
([
pos_label
,
neg_label
],
dim
=
0
)
return
(
node_pairs
,
labels
.
float
())
else
:
elif
self
.
compacted_node_pairs
is
not
None
:
# Edge classification.
return
(
self
.
compacted_node_pairs
,
self
.
labels
)
else
:
return
None
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy `MiniBatch` to the specified device using reflection."""
...
...
tests/python/pytorch/graphbolt/impl/test_minibatch.py
View file @
67cd09d7
...
...
@@ -786,6 +786,48 @@ def test_dgl_link_predication_hetero(mode):
minibatch
.
negative_node_pairs
[
etype
][
1
],
minibatch
.
compacted_negative_dsts
[
etype
].
view
(
-
1
),
)
node_pairs
,
labels
=
minibatch
.
node_pairs_with_labels
if
mode
==
"neg_src"
:
expect_node_pairs
=
{
"A:r:B"
:
(
torch
.
tensor
([
1
,
1
,
2
,
0
,
1
,
2
]),
torch
.
tensor
([
1
,
0
,
1
,
1
,
0
,
0
]),
),
"B:rr:A"
:
(
torch
.
tensor
([
0
,
1
,
1
,
2
,
0
,
2
]),
torch
.
tensor
([
1
,
0
,
1
,
1
,
0
,
0
]),
),
}
elif
mode
==
"neg_dst"
:
expect_node_pairs
=
{
"A:r:B"
:
(
torch
.
tensor
([
1
,
1
,
1
,
1
,
1
,
1
]),
torch
.
tensor
([
1
,
0
,
1
,
3
,
2
,
1
]),
),
"B:rr:A"
:
(
torch
.
tensor
([
0
,
1
,
0
,
0
,
1
,
1
]),
torch
.
tensor
([
1
,
0
,
2
,
1
,
3
,
1
]),
),
}
else
:
expect_node_pairs
=
{
"A:r:B"
:
(
torch
.
tensor
([
1
,
1
,
2
,
0
,
1
,
2
]),
torch
.
tensor
([
1
,
0
,
1
,
3
,
2
,
1
]),
),
"B:rr:A"
:
(
torch
.
tensor
([
0
,
1
,
1
,
2
,
0
,
2
]),
torch
.
tensor
([
1
,
0
,
2
,
1
,
3
,
1
]),
),
}
expect_labels
=
{
"A:r:B"
:
torch
.
tensor
([
1
,
1
,
0
,
0
,
0
,
0
]),
"B:rr:A"
:
torch
.
tensor
([
1
,
1
,
0
,
0
,
0
,
0
]),
}
for
etype
in
node_pairs
:
assert
torch
.
equal
(
node_pairs
[
etype
][
0
],
expect_node_pairs
[
etype
][
0
])
assert
torch
.
equal
(
node_pairs
[
etype
][
1
],
expect_node_pairs
[
etype
][
1
])
assert
torch
.
equal
(
labels
[
etype
],
expect_labels
[
etype
])
def
create_homo_minibatch_csc_format
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment