"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5f25818a0fd8747c46b27becc9c63dcfbbfeb638"
Unverified Commit 351c860a authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add check to NegativeSampler. (#6976)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 2e6ded06
...@@ -60,9 +60,12 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -60,9 +60,12 @@ class UniformNegativeSampler(NegativeSampler):
super().__init__(datapipe, negative_ratio) super().__init__(datapipe, negative_ratio)
self.graph = graph self.graph = graph
def _sample_with_etype(self, node_pairs, etype=None): def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
if not use_seeds:
return self.graph.sample_negative_edges_uniform( return self.graph.sample_negative_edges_uniform(
etype, etype,
node_pairs, node_pairs,
self.negative_ratio, self.negative_ratio,
) )
else:
raise NotImplementedError("Not implemented yet.")
...@@ -55,19 +55,24 @@ class NegativeSampler(MiniBatchTransformer): ...@@ -55,19 +55,24 @@ class NegativeSampler(MiniBatchTransformer):
An instance of 'MiniBatch' encompasses both positive and negative An instance of 'MiniBatch' encompasses both positive and negative
samples. samples.
""" """
if minibatch.seeds is None:
node_pairs = minibatch.node_pairs node_pairs = minibatch.node_pairs
assert node_pairs is not None assert node_pairs is not None
if isinstance(node_pairs, Mapping): if isinstance(node_pairs, Mapping):
minibatch.negative_srcs, minibatch.negative_dsts = {}, {} minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
for etype, pos_pairs in node_pairs.items(): for etype, pos_pairs in node_pairs.items():
self._collate( self._collate(
minibatch, self._sample_with_etype(pos_pairs, etype), etype minibatch,
self._sample_with_etype(pos_pairs, etype),
etype,
) )
else: else:
self._collate(minibatch, self._sample_with_etype(node_pairs)) self._collate(minibatch, self._sample_with_etype(node_pairs))
else:
raise NotImplementedError("Not implemented yet.")
return minibatch return minibatch
def _sample_with_etype(self, node_pairs, etype=None): def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
"""Generate negative pairs for a given etype form positive pairs """Generate negative pairs for a given etype form positive pairs
for a given etype. for a given etype.
...@@ -102,6 +107,7 @@ class NegativeSampler(MiniBatchTransformer): ...@@ -102,6 +107,7 @@ class NegativeSampler(MiniBatchTransformer):
etype : str etype : str
Canonical edge type. Canonical edge type.
""" """
if minibatch.seeds is None:
neg_src, neg_dst = neg_pairs neg_src, neg_dst = neg_pairs
if neg_src is not None: if neg_src is not None:
neg_src = neg_src.view(-1, self.negative_ratio) neg_src = neg_src.view(-1, self.negative_ratio)
...@@ -113,3 +119,5 @@ class NegativeSampler(MiniBatchTransformer): ...@@ -113,3 +119,5 @@ class NegativeSampler(MiniBatchTransformer):
else: else:
minibatch.negative_srcs = neg_src minibatch.negative_srcs = neg_src
minibatch.negative_dsts = neg_dst minibatch.negative_dsts = neg_dst
else:
raise NotImplementedError("Not implemented yet.")
...@@ -31,6 +31,33 @@ def test_NegativeSampler_invoke(): ...@@ -31,6 +31,33 @@ def test_NegativeSampler_invoke():
next(iter(negative_sampler)) next(iter(negative_sampler))
def test_UniformNegativeSampler_seeds_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
negative_ratio = 2
# Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
graph,
negative_ratio,
)
with pytest.raises(NotImplementedError):
next(iter(negative_sampler))
# Invoke UniformNegativeSampler via functional form.
negative_sampler = item_sampler.sample_uniform_negative(
graph,
negative_ratio,
)
with pytest.raises(NotImplementedError):
next(iter(negative_sampler))
def test_UniformNegativeSampler_invoke(): def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes. # Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment