"src/array/vscode:/vscode.git/clone" did not exist on "7bab1365e24993d549bd62f8fe8475ad9767b992"
Unverified Commit a7f1ab9f authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify `labels` dtype. (#7200)

parent 89e49439
...@@ -86,20 +86,9 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -86,20 +86,9 @@ class UniformNegativeSampler(NegativeSampler):
# Construct labels for all node pairs. # Construct labels for all node pairs.
pos_num = node_pairs.shape[0] pos_num = node_pairs.shape[0]
neg_num = seeds.shape[0] - pos_num neg_num = seeds.shape[0] - pos_num
labels = torch.cat( labels = torch.empty(pos_num + neg_num, device=seeds.device)
( labels[:pos_num] = 1
torch.ones( labels[pos_num:] = 0
pos_num,
dtype=torch.bool,
device=seeds.device,
),
torch.zeros(
neg_num,
dtype=torch.bool,
device=seeds.device,
),
),
)
return seeds, labels, indexes return seeds, labels, indexes
else: else:
return self.graph.sample_negative_edges_uniform( return self.graph.sample_negative_edges_uniform(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment