Unverified Commit 351c860a authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add check to NegativeSampler. (#6976)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 2e6ded06
...@@ -60,9 +60,12 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -60,9 +60,12 @@ class UniformNegativeSampler(NegativeSampler):
super().__init__(datapipe, negative_ratio) super().__init__(datapipe, negative_ratio)
self.graph = graph self.graph = graph
def _sample_with_etype(self, node_pairs, etype=None): def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
return self.graph.sample_negative_edges_uniform( if not use_seeds:
etype, return self.graph.sample_negative_edges_uniform(
node_pairs, etype,
self.negative_ratio, node_pairs,
) self.negative_ratio,
)
else:
raise NotImplementedError("Not implemented yet.")
...@@ -55,19 +55,24 @@ class NegativeSampler(MiniBatchTransformer): ...@@ -55,19 +55,24 @@ class NegativeSampler(MiniBatchTransformer):
An instance of 'MiniBatch' encompasses both positive and negative An instance of 'MiniBatch' encompasses both positive and negative
samples. samples.
""" """
node_pairs = minibatch.node_pairs if minibatch.seeds is None:
assert node_pairs is not None node_pairs = minibatch.node_pairs
if isinstance(node_pairs, Mapping): assert node_pairs is not None
minibatch.negative_srcs, minibatch.negative_dsts = {}, {} if isinstance(node_pairs, Mapping):
for etype, pos_pairs in node_pairs.items(): minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
self._collate( for etype, pos_pairs in node_pairs.items():
minibatch, self._sample_with_etype(pos_pairs, etype), etype self._collate(
) minibatch,
self._sample_with_etype(pos_pairs, etype),
etype,
)
else:
self._collate(minibatch, self._sample_with_etype(node_pairs))
else: else:
self._collate(minibatch, self._sample_with_etype(node_pairs)) raise NotImplementedError("Not implemented yet.")
return minibatch return minibatch
def _sample_with_etype(self, node_pairs, etype=None): def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
"""Generate negative pairs for a given etype form positive pairs """Generate negative pairs for a given etype form positive pairs
for a given etype. for a given etype.
...@@ -102,14 +107,17 @@ class NegativeSampler(MiniBatchTransformer): ...@@ -102,14 +107,17 @@ class NegativeSampler(MiniBatchTransformer):
etype : str etype : str
Canonical edge type. Canonical edge type.
""" """
neg_src, neg_dst = neg_pairs if minibatch.seeds is None:
if neg_src is not None: neg_src, neg_dst = neg_pairs
neg_src = neg_src.view(-1, self.negative_ratio) if neg_src is not None:
if neg_dst is not None: neg_src = neg_src.view(-1, self.negative_ratio)
neg_dst = neg_dst.view(-1, self.negative_ratio) if neg_dst is not None:
if etype is not None: neg_dst = neg_dst.view(-1, self.negative_ratio)
minibatch.negative_srcs[etype] = neg_src if etype is not None:
minibatch.negative_dsts[etype] = neg_dst minibatch.negative_srcs[etype] = neg_src
minibatch.negative_dsts[etype] = neg_dst
else:
minibatch.negative_srcs = neg_src
minibatch.negative_dsts = neg_dst
else: else:
minibatch.negative_srcs = neg_src raise NotImplementedError("Not implemented yet.")
minibatch.negative_dsts = neg_dst
...@@ -31,6 +31,33 @@ def test_NegativeSampler_invoke(): ...@@ -31,6 +31,33 @@ def test_NegativeSampler_invoke():
next(iter(negative_sampler)) next(iter(negative_sampler))
def test_UniformNegativeSampler_seeds_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
negative_ratio = 2
# Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
graph,
negative_ratio,
)
with pytest.raises(NotImplementedError):
next(iter(negative_sampler))
# Invoke UniformNegativeSampler via functional form.
negative_sampler = item_sampler.sample_uniform_negative(
graph,
negative_ratio,
)
with pytest.raises(NotImplementedError):
next(iter(negative_sampler))
def test_UniformNegativeSampler_invoke(): def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes. # Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment