Unverified Commit a7f1ab9f authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify `labels` dtype. (#7200)

parent 89e49439
......@@ -86,20 +86,9 @@ class UniformNegativeSampler(NegativeSampler):
# Construct labels for all node pairs.
pos_num = node_pairs.shape[0]
neg_num = seeds.shape[0] - pos_num
labels = torch.cat(
(
torch.ones(
pos_num,
dtype=torch.bool,
device=seeds.device,
),
torch.zeros(
neg_num,
dtype=torch.bool,
device=seeds.device,
),
),
)
labels = torch.empty(pos_num + neg_num, device=seeds.device)
labels[:pos_num] = 1
labels[pos_num:] = 0
return seeds, labels, indexes
else:
return self.graph.sample_negative_edges_uniform(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment