Unverified Commit e362f023 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Fix negative sampler hetero bugs (#6185)

parent f2d42266
......@@ -58,10 +58,18 @@ class NegativeSampler(Mapper):
"""
node_pairs = data.node_pair
if isinstance(node_pairs, Mapping):
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
data.label = {}
else:
data.negative_head, data.negative_tail = {}, {}
for etype, pos_pairs in node_pairs.items():
self._collate(
data, self._sample_with_etype(pos_pairs, etype), etype
)
if self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED:
data.negative_tail = None
if self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED:
data.negative_head = None
else:
self._collate(data, self._sample_with_etype(node_pairs))
return data
......@@ -101,7 +109,7 @@ class NegativeSampler(Mapper):
etype : (str, str, str)
Canonical edge type.
"""
pos_src, pos_dst = data.node_pair
pos_src, pos_dst = data.node_pair[etype] if etype else data.node_pair
neg_src, neg_dst = neg_pairs
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_label = torch.ones_like(pos_src)
......
......@@ -142,3 +142,66 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
assert len(pos_dst) == batch_size
assert len(neg_dst) == batch_size
assert neg_dst.numel() == batch_size * negative_ratio
def get_hetero_graph():
    """Build a small heterogeneous graph from CSC arrays via ``gb.from_csc``.

    The graph has 5 nodes (2 of type ``n1`` followed by 3 of type ``n2``,
    as encoded by ``node_type_offset``) and 10 edges split over the edge
    types ``("n1", "e1", "n2")`` and ``("n2", "e2", "n1")``.

    Edge list (first row is ``indptr`` expanded per edge, second row is
    ``indices``, third row is the per-edge type id):
        [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
        [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]

    Returns
    -------
    The graph object produced by ``gb.from_csc``.
    """
    # Type-name -> type-id tables consumed by the graph metadata.
    node_type_ids = {"n1": 0, "n2": 1}
    edge_type_ids = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
    graph_metadata = gb.GraphMetadata(node_type_ids, edge_type_ids)

    # CSC structure: two entries per node, ten edges in total.
    csc_indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    csc_indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])

    # Heterogeneous annotations: node id ranges per type and edge type ids.
    type_offsets = torch.LongTensor([0, 2, 5])
    edge_type_of_each_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])

    hetero_graph = gb.from_csc(
        csc_indptr,
        csc_indices,
        node_type_offset=type_offsets,
        type_per_edge=edge_type_of_each_edge,
        metadata=graph_metadata,
    )
    return hetero_graph
def to_link_block(data):
    """Wrap ``data`` as the ``node_pair`` of a new ``gb.LinkPredictionBlock``."""
    return gb.LinkPredictionBlock(node_pair=data)
@pytest.mark.parametrize(
    "format",
    [
        gb.LinkPredictionEdgeFormat.INDEPENDENT,
        gb.LinkPredictionEdgeFormat.CONDITIONED,
        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
    ],
)
def test_NegativeSampler_Hetero_Data(format):
    """Run the uniform negative sampler over heterogeneous node pairs.

    For every link-prediction edge format, the pipeline
    item set -> minibatch sampler -> link block -> negative sampler
    must be iterable end to end and yield exactly 5 minibatches.
    """
    # Positive node pairs keyed by canonical edge type: 4 pairs for "e1"
    # and 6 pairs for "e2".
    e1_pairs = gb.ItemSet(
        (
            torch.LongTensor([0, 0, 1, 1]),
            torch.LongTensor([0, 2, 0, 1]),
        )
    )
    e2_pairs = gb.ItemSet(
        (
            torch.LongTensor([0, 0, 1, 1, 2, 2]),
            torch.LongTensor([0, 1, 1, 0, 0, 1]),
        )
    )
    hetero_items = gb.ItemSetDict(
        {
            ("n1", "e1", "n2"): e1_pairs,
            ("n2", "e2", "n1"): e2_pairs,
        }
    )

    # Assemble the data pipeline; one negative is requested per positive.
    batched_pairs = gb.MinibatchSampler(hetero_items, batch_size=2)
    link_blocks = Mapper(batched_pairs, to_link_block)
    sampler = gb.UniformNegativeSampler(
        link_blocks, 1, format, get_hetero_graph()
    )

    # Presumably 4 / 2 + 6 / 2 = 5 batches, i.e. batches do not mix edge
    # types -- confirm against MinibatchSampler.
    produced = list(sampler)
    assert len(produced) == 5
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment