Unverified Commit e7c3454f authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Fix `NegativeSampler` seeds support. (#7043)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent af870387
......@@ -82,7 +82,9 @@ class NegativeSampler(MiniBatchTransformer):
minibatch.seeds[etype],
minibatch.labels[etype],
minibatch.indexes[etype],
) = self._sample_with_etype(pos_pairs, use_seeds=True)
) = self._sample_with_etype(
pos_pairs, etype, use_seeds=True
)
else:
(
minibatch.seeds,
......
......@@ -293,7 +293,23 @@ def test_NegativeSampler_Hetero_Data():
),
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
batch_size = 2
negative_ratio = 1
item_sampler = gb.ItemSampler(itemset, batch_size=batch_size)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio)
assert len(list(negative_dp)) == 5
# Perform negative sampling.
expected_neg_src = [
{"n1:e1:n2": torch.tensor([0, 0])},
{"n1:e1:n2": torch.tensor([1, 1])},
{"n2:e2:n1": torch.tensor([0, 0])},
{"n2:e2:n1": torch.tensor([1, 1])},
{"n2:e2:n1": torch.tensor([2, 2])},
]
for i, data in enumerate(negative_dp):
# Check negative seeds value.
for etype, seeds_data in data.seeds.items():
neg_src = seeds_data[batch_size:, 0]
neg_dst = seeds_data[batch_size:, 1]
assert torch.equal(expected_neg_src[i][etype], neg_src)
assert (neg_dst < 3).all(), neg_dst
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment