Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e26d2064
Unverified
Commit
e26d2064
authored
Jul 13, 2022
by
Chang Liu
Committed by
GitHub
Jul 14, 2022
Browse files
[Example][Bugfix] Fix link pred example in graphsage (#4255)
parent
1c9528f5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
examples/pytorch/graphsage/link_pred.py
examples/pytorch/graphsage/link_pred.py
+4
-4
No files found.
examples/pytorch/graphsage/link_pred.py
View file @
e26d2064
...
@@ -106,7 +106,7 @@ def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500):
...
@@ -106,7 +106,7 @@ def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500):
h_src
=
node_emb
[
src
[
start
:
end
]][:,
None
,
:].
to
(
device
)
h_src
=
node_emb
[
src
[
start
:
end
]][:,
None
,
:].
to
(
device
)
h_dst
=
node_emb
[
all_dst
.
view
(
-
1
)].
view
(
*
all_dst
.
shape
,
-
1
).
to
(
device
)
h_dst
=
node_emb
[
all_dst
.
view
(
-
1
)].
view
(
*
all_dst
.
shape
,
-
1
).
to
(
device
)
pred
=
model
.
predict
(
h_src
,
h_dst
).
squeeze
(
-
1
)
pred
=
model
.
predict
(
h_src
,
h_dst
).
squeeze
(
-
1
)
relevance
=
torch
.
zeros
(
*
pred
.
shape
,
dtype
=
torch
.
bool
)
relevance
=
torch
.
zeros
(
*
pred
.
shape
,
dtype
=
torch
.
bool
)
.
to
(
pred
.
device
)
relevance
[:,
0
]
=
True
relevance
[:,
0
]
=
True
rr
[
start
:
end
]
=
MF
.
retrieval_reciprocal_rank
(
pred
,
relevance
)
rr
[
start
:
end
]
=
MF
.
retrieval_reciprocal_rank
(
pred
,
relevance
)
return
rr
.
mean
()
return
rr
.
mean
()
...
@@ -117,9 +117,9 @@ def evaluate(model, edge_split, device, num_workers):
...
@@ -117,9 +117,9 @@ def evaluate(model, edge_split, device, num_workers):
node_emb
=
model
.
inference
(
graph
,
device
,
4096
,
num_workers
,
'cpu'
)
node_emb
=
model
.
inference
(
graph
,
device
,
4096
,
num_workers
,
'cpu'
)
results
=
[]
results
=
[]
for
split
in
[
'valid'
,
'test'
]:
for
split
in
[
'valid'
,
'test'
]:
src
=
edge_split
[
split
][
'source_node'
].
to
(
device
)
src
=
edge_split
[
split
][
'source_node'
].
to
(
node_emb
.
device
)
dst
=
edge_split
[
split
][
'target_node'
].
to
(
device
)
dst
=
edge_split
[
split
][
'target_node'
].
to
(
node_emb
.
device
)
neg_dst
=
edge_split
[
split
][
'target_node_neg'
].
to
(
device
)
neg_dst
=
edge_split
[
split
][
'target_node_neg'
].
to
(
node_emb
.
device
)
results
.
append
(
compute_mrr
(
model
,
node_emb
,
src
,
dst
,
neg_dst
,
device
))
results
.
append
(
compute_mrr
(
model
,
node_emb
,
src
,
dst
,
neg_dst
,
device
))
return
results
return
results
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment