Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a7f1ab9f
Unverified
Commit
a7f1ab9f
authored
Mar 07, 2024
by
yxy235
Committed by
GitHub
Mar 07, 2024
Browse files
[GraphBolt] Modify `labels` dtype. (#7200)
parent
89e49439
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
14 deletions
+3
-14
python/dgl/graphbolt/impl/uniform_negative_sampler.py
python/dgl/graphbolt/impl/uniform_negative_sampler.py
+3
-14
No files found.
python/dgl/graphbolt/impl/uniform_negative_sampler.py
View file @
a7f1ab9f
...
...
@@ -86,20 +86,9 @@ class UniformNegativeSampler(NegativeSampler):
# Construct labels for all node pairs.
pos_num
=
node_pairs
.
shape
[
0
]
neg_num
=
seeds
.
shape
[
0
]
-
pos_num
labels
=
torch
.
cat
(
(
torch
.
ones
(
pos_num
,
dtype
=
torch
.
bool
,
device
=
seeds
.
device
,
),
torch
.
zeros
(
neg_num
,
dtype
=
torch
.
bool
,
device
=
seeds
.
device
,
),
),
)
labels
=
torch
.
empty
(
pos_num
+
neg_num
,
device
=
seeds
.
device
)
labels
[:
pos_num
]
=
1
labels
[
pos_num
:]
=
0
return
seeds
,
labels
,
indexes
else
:
return
self
.
graph
.
sample_negative_edges_uniform
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment