Unverified Commit e7c3454f authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Fix `NegativeSampler` seeds support. (#7043)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent af870387
...@@ -82,7 +82,9 @@ class NegativeSampler(MiniBatchTransformer): ...@@ -82,7 +82,9 @@ class NegativeSampler(MiniBatchTransformer):
minibatch.seeds[etype], minibatch.seeds[etype],
minibatch.labels[etype], minibatch.labels[etype],
minibatch.indexes[etype], minibatch.indexes[etype],
) = self._sample_with_etype(pos_pairs, use_seeds=True) ) = self._sample_with_etype(
pos_pairs, etype, use_seeds=True
)
else: else:
( (
minibatch.seeds, minibatch.seeds,
......
...@@ -293,7 +293,23 @@ def test_NegativeSampler_Hetero_Data(): ...@@ -293,7 +293,23 @@ def test_NegativeSampler_Hetero_Data():
), ),
} }
) )
batch_size = 2
item_sampler = gb.ItemSampler(itemset, batch_size=2) negative_ratio = 1
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1) item_sampler = gb.ItemSampler(itemset, batch_size=batch_size)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio)
assert len(list(negative_dp)) == 5 assert len(list(negative_dp)) == 5
# Perform negative sampling.
expected_neg_src = [
{"n1:e1:n2": torch.tensor([0, 0])},
{"n1:e1:n2": torch.tensor([1, 1])},
{"n2:e2:n1": torch.tensor([0, 0])},
{"n2:e2:n1": torch.tensor([1, 1])},
{"n2:e2:n1": torch.tensor([2, 2])},
]
for i, data in enumerate(negative_dp):
# Check negative seeds value.
for etype, seeds_data in data.seeds.items():
neg_src = seeds_data[batch_size:, 0]
neg_dst = seeds_data[batch_size:, 1]
assert torch.equal(expected_neg_src[i][etype], neg_src)
assert (neg_dst < 3).all(), neg_dst
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment