Unverified Commit 351c860a authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add check to NegativeSampler. (#6976)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 2e6ded06
......@@ -60,9 +60,12 @@ class UniformNegativeSampler(NegativeSampler):
super().__init__(datapipe, negative_ratio)
self.graph = graph
def _sample_with_etype(self, node_pairs, etype=None):
return self.graph.sample_negative_edges_uniform(
etype,
node_pairs,
self.negative_ratio,
)
def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
if not use_seeds:
return self.graph.sample_negative_edges_uniform(
etype,
node_pairs,
self.negative_ratio,
)
else:
raise NotImplementedError("Not implemented yet.")
......@@ -55,19 +55,24 @@ class NegativeSampler(MiniBatchTransformer):
An instance of 'MiniBatch' encompasses both positive and negative
samples.
"""
node_pairs = minibatch.node_pairs
assert node_pairs is not None
if isinstance(node_pairs, Mapping):
minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
for etype, pos_pairs in node_pairs.items():
self._collate(
minibatch, self._sample_with_etype(pos_pairs, etype), etype
)
if minibatch.seeds is None:
node_pairs = minibatch.node_pairs
assert node_pairs is not None
if isinstance(node_pairs, Mapping):
minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
for etype, pos_pairs in node_pairs.items():
self._collate(
minibatch,
self._sample_with_etype(pos_pairs, etype),
etype,
)
else:
self._collate(minibatch, self._sample_with_etype(node_pairs))
else:
self._collate(minibatch, self._sample_with_etype(node_pairs))
raise NotImplementedError("Not implemented yet.")
return minibatch
def _sample_with_etype(self, node_pairs, etype=None):
def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
"""Generate negative pairs for a given etype form positive pairs
for a given etype.
......@@ -102,14 +107,17 @@ class NegativeSampler(MiniBatchTransformer):
etype : str
Canonical edge type.
"""
neg_src, neg_dst = neg_pairs
if neg_src is not None:
neg_src = neg_src.view(-1, self.negative_ratio)
if neg_dst is not None:
neg_dst = neg_dst.view(-1, self.negative_ratio)
if etype is not None:
minibatch.negative_srcs[etype] = neg_src
minibatch.negative_dsts[etype] = neg_dst
if minibatch.seeds is None:
neg_src, neg_dst = neg_pairs
if neg_src is not None:
neg_src = neg_src.view(-1, self.negative_ratio)
if neg_dst is not None:
neg_dst = neg_dst.view(-1, self.negative_ratio)
if etype is not None:
minibatch.negative_srcs[etype] = neg_src
minibatch.negative_dsts[etype] = neg_dst
else:
minibatch.negative_srcs = neg_src
minibatch.negative_dsts = neg_dst
else:
minibatch.negative_srcs = neg_src
minibatch.negative_dsts = neg_dst
raise NotImplementedError("Not implemented yet.")
......@@ -31,6 +31,33 @@ def test_NegativeSampler_invoke():
next(iter(negative_sampler))
def test_UniformNegativeSampler_seeds_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
negative_ratio = 2
# Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
graph,
negative_ratio,
)
with pytest.raises(NotImplementedError):
next(iter(negative_sampler))
# Invoke UniformNegativeSampler via functional form.
negative_sampler = item_sampler.sample_uniform_negative(
graph,
negative_ratio,
)
with pytest.raises(NotImplementedError):
next(iter(negative_sampler))
def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment