"...pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "85c030f67332a60b6dd8d4259c4d233767517c2c"
Unverified Commit 44f4b0e2 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt]Adapt negative sampler (#6165)

parent ceb25724
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
from .data_format import LinkPredictionEdgeFormat from .data_format import LinkPredictionEdgeFormat
from .link_prediction_block import LinkPredictionBlock
class NegativeSampler(Mapper): class NegativeSampler(Mapper):
...@@ -50,21 +51,20 @@ class NegativeSampler(Mapper): ...@@ -50,21 +51,20 @@ class NegativeSampler(Mapper):
Returns Returns
------- -------
Tuple[Tensor] or Dict[etype, Tuple[Tensor]] LinkPredictionBlock
A collection of edges or a dictionary that maps etypes to edges, An instance of 'LinkPredictionBlock' encompasses both positive and
which includes both positive and negative samples. negative samples.
""" """
data = LinkPredictionBlock(node_pair=node_pairs)
if isinstance(node_pairs, Mapping): if isinstance(node_pairs, Mapping):
return { for etype, pos_pairs in node_pairs.items():
etype: self._collate( self._collate(
pos_pairs, self._sample_with_etype(pos_pairs, etype) data, self._sample_with_etype(pos_pairs, etype), etype
) )
for etype, pos_pairs in node_pairs.items()
}
else: else:
return self._collate( self._collate(data, self._sample_with_etype(node_pairs))
node_pairs, self._sample_with_etype(node_pairs, None) return data
)
def _sample_with_etype(self, node_pairs, etype=None): def _sample_with_etype(self, node_pairs, etype=None):
"""Generate negative pairs for a given etype form positive pairs """Generate negative pairs for a given etype form positive pairs
...@@ -86,49 +86,56 @@ class NegativeSampler(Mapper): ...@@ -86,49 +86,56 @@ class NegativeSampler(Mapper):
""" """
raise NotImplementedError raise NotImplementedError
def _collate(self, pos_pairs, neg_pairs): def _collate(self, data, neg_pairs, etype=None):
"""Collates positive and negative samples. """Collates positive and negative samples into data.
Parameters Parameters
---------- ----------
pos_pairs : Tuple[Tensor] data : LinkPredictionBlock
A tuple of tensors represents source-destination node pairs of The input data, which contains positive node pairs, will be filled
positive edges, where positive means the edge must exist in with negative information in this function.
the graph.
neg_pairs : Tuple[Tensor] neg_pairs : Tuple[Tensor]
A tuple of tensors represents source-destination node pairs of A tuple of tensors represents source-destination node pairs of
negative edges, where negative means the edge may not exist in negative edges, where negative means the edge may not exist in
the graph. the graph.
etype : (str, str, str)
Returns Canonical edge type.
-------
Tuple[Tensor]
A mixed collection of positive and negative node pairs.
""" """
pos_src, pos_dst = data.node_pair
neg_src, neg_dst = neg_pairs
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT: if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_src, pos_dst = pos_pairs
neg_src, neg_dst = neg_pairs
pos_label = torch.ones_like(pos_src) pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src) neg_label = torch.zeros_like(neg_src)
src = torch.cat([pos_src, neg_src]) src = torch.cat([pos_src, neg_src])
dst = torch.cat([pos_dst, neg_dst]) dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label]) label = torch.cat([pos_label, neg_label])
return (src, dst, label) if etype:
elif self.output_format == LinkPredictionEdgeFormat.CONDITIONED: data.node_pair[etype] = (src, dst)
pos_src, pos_dst = pos_pairs data.label[etype] = label
neg_src, neg_dst = neg_pairs else:
neg_src = neg_src.view(-1, self.negative_ratio) data.node_pair = (src, dst)
neg_dst = neg_dst.view(-1, self.negative_ratio) data.label = label
return (pos_src, pos_dst, neg_src, neg_dst)
elif self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED:
pos_src, pos_dst = pos_pairs
neg_src, _ = neg_pairs
neg_src = neg_src.view(-1, self.negative_ratio)
return (pos_src, pos_dst, neg_src)
elif self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED:
pos_src, pos_dst = pos_pairs
_, neg_dst = neg_pairs
neg_dst = neg_dst.view(-1, self.negative_ratio)
return (pos_src, pos_dst, neg_dst)
else: else:
raise ValueError("Unsupported output format.") if self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
neg_src = neg_src.view(-1, self.negative_ratio)
neg_dst = neg_dst.view(-1, self.negative_ratio)
elif (
self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED
):
neg_src = neg_src.view(-1, self.negative_ratio)
neg_dst = None
elif (
self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED
):
neg_dst = neg_dst.view(-1, self.negative_ratio)
neg_src = None
else:
raise TypeError(
f"Unsupported output format {self.output_format}."
)
if etype:
data.negative_head[etype] = neg_src
data.negative_tail[etype] = neg_dst
else:
data.negative_head = neg_src
data.negative_tail = neg_dst
...@@ -26,7 +26,8 @@ def test_NegativeSampler_Independent_Format(negative_ratio): ...@@ -26,7 +26,8 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
src, dst, label = data src, dst = data.node_pair
label = data.label
# Assertation # Assertation
assert len(src) == batch_size * (negative_ratio + 1) assert len(src) == batch_size * (negative_ratio + 1)
assert len(dst) == batch_size * (negative_ratio + 1) assert len(dst) == batch_size * (negative_ratio + 1)
...@@ -57,7 +58,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio): ...@@ -57,7 +58,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
pos_src, pos_dst, neg_src, neg_dst = data pos_src, pos_dst = data.node_pair
neg_src, neg_dst = data.negative_head, data.negative_tail
# Assertation # Assertation
assert len(pos_src) == batch_size assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size assert len(pos_dst) == batch_size
...@@ -91,7 +93,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio): ...@@ -91,7 +93,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
pos_src, pos_dst, neg_src = data pos_src, pos_dst = data.node_pair
neg_src = data.negative_head
# Assertation # Assertation
assert len(pos_src) == batch_size assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size assert len(pos_dst) == batch_size
...@@ -123,7 +126,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio): ...@@ -123,7 +126,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
pos_src, pos_dst, neg_dst = data pos_src, pos_dst = data.node_pair
neg_dst = data.negative_tail
# Assertation # Assertation
assert len(pos_src) == batch_size assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size assert len(pos_dst) == batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment