Unverified Commit 44f4b0e2 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt]Adapt negative sampler (#6165)

parent ceb25724
......@@ -6,6 +6,7 @@ import torch
from torchdata.datapipes.iter import Mapper
from .data_format import LinkPredictionEdgeFormat
from .link_prediction_block import LinkPredictionBlock
class NegativeSampler(Mapper):
......@@ -50,21 +51,20 @@ class NegativeSampler(Mapper):
Returns
-------
Tuple[Tensor] or Dict[etype, Tuple[Tensor]]
A collection of edges or a dictionary that maps etypes to edges,
which includes both positive and negative samples.
LinkPredictionBlock
An instance of 'LinkPredictionBlock' encompasses both positive and
negative samples.
"""
data = LinkPredictionBlock(node_pair=node_pairs)
if isinstance(node_pairs, Mapping):
return {
etype: self._collate(
pos_pairs, self._sample_with_etype(pos_pairs, etype)
for etype, pos_pairs in node_pairs.items():
self._collate(
data, self._sample_with_etype(pos_pairs, etype), etype
)
for etype, pos_pairs in node_pairs.items()
}
else:
return self._collate(
node_pairs, self._sample_with_etype(node_pairs, None)
)
self._collate(data, self._sample_with_etype(node_pairs))
return data
def _sample_with_etype(self, node_pairs, etype=None):
"""Generate negative pairs for a given etype form positive pairs
......@@ -86,49 +86,56 @@ class NegativeSampler(Mapper):
"""
raise NotImplementedError
def _collate(self, pos_pairs, neg_pairs):
"""Collates positive and negative samples.
def _collate(self, data, neg_pairs, etype=None):
"""Collates positive and negative samples into data.
Parameters
----------
pos_pairs : Tuple[Tensor]
A tuple of tensors represents source-destination node pairs of
positive edges, where positive means the edge must exist in
the graph.
data : LinkPredictionBlock
The input data, which contains positive node pairs, will be filled
with negative information in this function.
neg_pairs : Tuple[Tensor]
A tuple of tensors represents source-destination node pairs of
negative edges, where negative means the edge may not exist in
the graph.
Returns
-------
Tuple[Tensor]
A mixed collection of positive and negative node pairs.
etype : (str, str, str)
Canonical edge type.
"""
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_src, pos_dst = pos_pairs
pos_src, pos_dst = data.node_pair
neg_src, neg_dst = neg_pairs
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
src = torch.cat([pos_src, neg_src])
dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label])
return (src, dst, label)
elif self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
pos_src, pos_dst = pos_pairs
neg_src, neg_dst = neg_pairs
if etype:
data.node_pair[etype] = (src, dst)
data.label[etype] = label
else:
data.node_pair = (src, dst)
data.label = label
else:
if self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
neg_src = neg_src.view(-1, self.negative_ratio)
neg_dst = neg_dst.view(-1, self.negative_ratio)
return (pos_src, pos_dst, neg_src, neg_dst)
elif self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED:
pos_src, pos_dst = pos_pairs
neg_src, _ = neg_pairs
elif (
self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED
):
neg_src = neg_src.view(-1, self.negative_ratio)
return (pos_src, pos_dst, neg_src)
elif self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED:
pos_src, pos_dst = pos_pairs
_, neg_dst = neg_pairs
neg_dst = None
elif (
self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED
):
neg_dst = neg_dst.view(-1, self.negative_ratio)
return (pos_src, pos_dst, neg_dst)
neg_src = None
else:
raise TypeError(
f"Unsupported output format {self.output_format}."
)
if etype:
data.negative_head[etype] = neg_src
data.negative_tail[etype] = neg_dst
else:
raise ValueError("Unsupported output format.")
data.negative_head = neg_src
data.negative_tail = neg_dst
......@@ -26,7 +26,8 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
src, dst, label = data
src, dst = data.node_pair
label = data.label
# Assertation
assert len(src) == batch_size * (negative_ratio + 1)
assert len(dst) == batch_size * (negative_ratio + 1)
......@@ -57,7 +58,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
pos_src, pos_dst, neg_src, neg_dst = data
pos_src, pos_dst = data.node_pair
neg_src, neg_dst = data.negative_head, data.negative_tail
# Assertation
assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size
......@@ -91,7 +93,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
pos_src, pos_dst, neg_src = data
pos_src, pos_dst = data.node_pair
neg_src = data.negative_head
# Assertation
assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size
......@@ -123,7 +126,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
pos_src, pos_dst, neg_dst = data
pos_src, pos_dst = data.node_pair
neg_dst = data.negative_tail
# Assertation
assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment