Unverified Commit 5da7d391 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt]Add negative sampler udf (#6053)

parent 8c213ef1
......@@ -15,6 +15,8 @@ from .impl import *
from .dataloader import *
from .subgraph_sampler import *
from .sampled_subgraph import *
from .link_data_format import *
from .negative_sampler import *
from .utils import unique_and_compact_node_pairs
......
......@@ -4,3 +4,4 @@ from .ondisk_metadata import *
from .torch_based_feature_store import *
from .csc_sampling_graph import *
from .sampled_subgraph_impl import *
from .uniform_negative_sampler import *
"""Uniform negative sampler for GraphBolt."""
from ..negative_sampler import NegativeSampler
class UniformNegativeSampler(NegativeSampler):
    """Negative sampler that draws destination nodes uniformly at random.

    For every positive edge ``(u, v)`` the sampler is expected to produce
    `negative_ratio` negative edges ``(u, v')``, with ``v'`` picked uniformly
    from all the nodes in the graph. Note that "negative" does not guarantee
    absence: a sampled pair may in fact be an existing edge of the graph
    (i.e. a false negative).
    """

    def __init__(
        self,
        datapipe,
        negative_ratio,
        link_data_format,
        graph,
    ):
        """
        Initialize a uniform negative sampler.

        Parameters
        ----------
        datapipe : DataPipe
            The upstream datapipe that yields positive node pairs.
        negative_ratio : int
            How many negative samples to produce per positive sample.
        link_data_format : LinkDataFormat
            Selects the layout of the emitted data:

            - Conditioned format: quadruples
              `[u, v, [negative heads], [negative tails]]`, where 'u' and
              'v' are the source and destination nodes of the positive
              edges and 'negative heads' / 'negative tails' are the source
              and destination nodes of the negative edges.
            - Independent format: triples `[u, v, label]`, where 'u' and
              'v' are the source and destination nodes of an edge and
              'label' marks the edge as negative (0) or positive (1).
        graph : CSCSamplingGraph
            The graph on which negative sampling is performed.

        Examples
        --------
        Independent format (sampled values are random; the output below is
        illustrative):

        >>> from dgl import graphbolt as gb
        >>> indptr = torch.LongTensor([0, 2, 4, 5])
        >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
        >>> graph = gb.from_csc(indptr, indices)
        >>> link_data_format = gb.LinkDataFormat.INDEPENDENT
        >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
        >>> item_set = gb.ItemSet(node_pairs)
        >>> minibatch_sampler = gb.MinibatchSampler(
        ...     item_set, batch_size=1,
        ... )
        >>> neg_sampler = gb.UniformNegativeSampler(
        ...     minibatch_sampler, 2, link_data_format, graph)
        >>> for data in neg_sampler:
        ...     print(data)
        ...
        (tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
        (tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))

        Conditioned format:

        >>> from dgl import graphbolt as gb
        >>> indptr = torch.LongTensor([0, 2, 4, 5])
        >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
        >>> graph = gb.from_csc(indptr, indices)
        >>> link_data_format = gb.LinkDataFormat.CONDITIONED
        >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
        >>> item_set = gb.ItemSet(node_pairs)
        >>> minibatch_sampler = gb.MinibatchSampler(
        ...     item_set, batch_size=1,
        ... )
        >>> neg_sampler = gb.UniformNegativeSampler(
        ...     minibatch_sampler, 2, link_data_format, graph)
        >>> for data in neg_sampler:
        ...     print(data)
        ...
        (tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[2, 1]]))
        (tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[1, 2]]))
        """
        super().__init__(datapipe, negative_ratio, link_data_format)
        self.graph = graph

    def _sample_with_etype(self, node_pairs, etype=None):
        """Produce negative pairs for one edge type by delegating to the graph.

        Forwards `etype`, `node_pairs` and `self.negative_ratio` to
        ``self.graph.sample_negative_edges_uniform`` and returns its result
        unchanged.
        """
        sample_uniform = self.graph.sample_negative_edges_uniform
        return sample_uniform(etype, node_pairs, self.negative_ratio)
"""Linked data format."""
from enum import Enum
__all__ = ["LinkDataFormat"]
class LinkDataFormat(Enum):
    """Enumeration of the two data layouts used in link prediction.

    Attributes
    ----------
    CONDITIONED
        The 'conditioned' layout: data is organised as quadruples
        `[u, v, [negative heads], [negative tails]]`, giving the source and
        destination nodes of the positive and of the negative edges.
    INDEPENDENT
        The 'independent' layout: data is organised as triples
        `[u, v, label]`, giving the source and destination nodes of an edge
        together with a label (0 or 1) that marks it as negative or positive.
    """

    CONDITIONED = "conditioned"
    INDEPENDENT = "independent"
"""Negative samplers."""
from _collections_abc import Mapping
import torch
from torchdata.datapipes.iter import Mapper
from .link_data_format import LinkDataFormat
class NegativeSampler(Mapper):
    """
    A negative sampler used to generate negative samples and return
    a mix of positive and negative samples.

    This is a base class: subclasses supply the actual sampling strategy by
    overriding `_sample_with_etype`.
    """

    def __init__(
        self,
        datapipe,
        negative_ratio,
        link_data_format,
    ):
        """
        Initialization for a negative sampler.

        Parameters
        ----------
        datapipe : DataPipe
            The datapipe.
        negative_ratio : int
            The proportion of negative samples to positive samples.
        link_data_format : LinkDataFormat
            Determines the format of the output data:
                - Conditioned format: Outputs data as quadruples
                `[u, v, [negative heads], [negative tails]]`. Here, 'u' and 'v'
                are the source and destination nodes of positive edges,  while
                'negative heads' and 'negative tails' refer to the source and
                destination nodes of negative edges.
                - Independent format: Outputs data as triples `[u, v, label]`.
                In this case, 'u' and 'v' are the source and destination nodes
                of an edge, and 'label' indicates whether the edge is negative
                (0) or positive (1).
        """
        # Every item coming out of `datapipe` is mapped through `_sample`.
        super().__init__(datapipe, self._sample)
        assert negative_ratio > 0, "Negative_ratio should be positive Integer."
        self.negative_ratio = negative_ratio
        self.link_data_format = link_data_format

    def _sample(self, node_pairs):
        """
        Generate a mix of positive and negative samples.

        Parameters
        ----------
        node_pairs : Tuple[Tensor] or Dict[etype, Tuple[Tensor]]
            A tuple of tensors or a dictionary represents source-destination
            node pairs of positive edges, where positive means the edge must
            exist in the graph.

        Returns
        -------
        Tuple[Tensor] or Dict[etype, Tuple[Tensor]]
            A collection of edges or a dictionary that maps etypes to edges,
            which includes both positive and negative samples.
        """
        if isinstance(node_pairs, Mapping):
            # Heterogeneous input: sample and collate per edge type.
            return {
                etype: self._collate(
                    pos_pairs, self._sample_with_etype(pos_pairs, etype)
                )
                for etype, pos_pairs in node_pairs.items()
            }
        else:
            return self._collate(
                node_pairs, self._sample_with_etype(node_pairs, None)
            )

    def _sample_with_etype(self, node_pairs, etype=None):
        """Generate negative pairs for a given etype from positive pairs
        for a given etype.

        Subclasses must override this method; the base implementation only
        raises.

        Parameters
        ----------
        node_pairs : Tuple[Tensor]
            A tuple of tensors represents source-destination node pairs of
            positive edges, where positive means the edge must exist in the
            graph.
        etype : (str, str, str)
            Canonical edge type.

        Returns
        -------
        Tuple[Tensor]
            A collection of negative node pairs.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        # Fail loudly here instead of implicitly returning None, which would
        # only surface later as an unrelated unpacking error in `_collate`.
        raise NotImplementedError

    def _collate(self, pos_pairs, neg_pairs):
        """Collates positive and negative samples.

        Parameters
        ----------
        pos_pairs : Tuple[Tensor]
            A tuple of tensors represents source-destination node pairs of
            positive edges, where positive means the edge must exist in
            the graph.
        neg_pairs : Tuple[Tensor]
            A tuple of tensors represents source-destination node pairs of
            negative edges, where negative means the edge may not exist in
            the graph.

        Returns
        -------
        Tuple[Tensor]
            A mixed collection of positive and negative node pairs:
            `(src, dst, label)` for the independent format, or
            `(pos_src, pos_dst, neg_src, neg_dst)` for the conditioned format.

        Raises
        ------
        ValueError
            If `self.link_data_format` is neither INDEPENDENT nor CONDITIONED.
        """
        if self.link_data_format == LinkDataFormat.INDEPENDENT:
            pos_src, pos_dst = pos_pairs
            neg_src, neg_dst = neg_pairs
            # Labels mirror the source tensors: 1 for positives, 0 for
            # negatives.
            pos_label = torch.ones_like(pos_src)
            neg_label = torch.zeros_like(neg_src)
            # Positives always come first, followed by the negatives.
            src = torch.cat([pos_src, neg_src])
            dst = torch.cat([pos_dst, neg_dst])
            label = torch.cat([pos_label, neg_label])
            return (src, dst, label)
        elif self.link_data_format == LinkDataFormat.CONDITIONED:
            pos_src, pos_dst = pos_pairs
            neg_src, neg_dst = neg_pairs
            # Reshape the flat negatives into rows of `negative_ratio`
            # consecutive entries.
            # NOTE(review): this assumes the sampler stores the negatives of
            # one positive edge contiguously -- confirm against the
            # `_sample_with_etype` implementation in use.
            neg_src = neg_src.view(-1, self.negative_ratio)
            neg_dst = neg_dst.view(-1, self.negative_ratio)
            return (pos_src, pos_dst, neg_src, neg_dst)
        else:
            raise ValueError("Unsupported link data format.")
import dgl.graphbolt as gb
import gb_test_utils
import pytest
import torch
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_NegativeSampler_Independent_Format(negative_ratio):
    """Check sizes and labels of batches emitted in the INDEPENDENT format."""
    # Random CSC sampling graph provided by the test utilities.
    graph = gb_test_utils.rand_csc_graph(100, 0.05)

    # Positive pairs: sources 0..29 paired with destinations 30..59.
    num_seeds = 30
    batch_size = 10
    sources = torch.arange(0, num_seeds)
    destinations = torch.arange(num_seeds, num_seeds * 2)
    item_set = gb.ItemSet((sources, destinations))
    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)

    # Sampler under test.
    negative_sampler = gb.UniformNegativeSampler(
        minibatch_sampler,
        negative_ratio,
        gb.LinkDataFormat.INDEPENDENT,
        graph,
    )

    # Each batch holds its positives plus `negative_ratio` negatives apiece.
    expected_len = batch_size * (negative_ratio + 1)
    for src, dst, label in negative_sampler:
        assert len(src) == expected_len
        assert len(dst) == expected_len
        assert len(label) == expected_len
        # Positives occupy the head of the batch, negatives the remainder.
        assert torch.all(torch.eq(label[:batch_size], 1))
        assert torch.all(torch.eq(label[batch_size:], 0))
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_NegativeSampler_Conditioned_Format(negative_ratio):
    """Check shapes and negative sources of batches in CONDITIONED format."""
    # Random CSC sampling graph provided by the test utilities.
    graph = gb_test_utils.rand_csc_graph(100, 0.05)

    # Positive pairs: sources 0..29 paired with destinations 30..59.
    num_seeds = 30
    batch_size = 10
    sources = torch.arange(0, num_seeds)
    destinations = torch.arange(num_seeds, num_seeds * 2)
    item_set = gb.ItemSet((sources, destinations))
    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)

    # Sampler under test.
    negative_sampler = gb.UniformNegativeSampler(
        minibatch_sampler,
        negative_ratio,
        gb.LinkDataFormat.CONDITIONED,
        graph,
    )

    total_negatives = batch_size * negative_ratio
    for pos_src, pos_dst, neg_src, neg_dst in negative_sampler:
        # One entry (row) per positive edge in every returned tensor.
        assert len(pos_src) == batch_size
        assert len(pos_dst) == batch_size
        assert len(neg_src) == batch_size
        assert len(neg_dst) == batch_size
        # Each row carries `negative_ratio` negatives.
        assert neg_src.numel() == total_negatives
        assert neg_dst.numel() == total_negatives
        # NOTE(review): `Tensor.repeat` tiles the whole batch, so after the
        # `view` a row generally mixes several different sources instead of
        # repeating a single one. Confirm this is the intended layout of
        # the negative heads.
        expected_src = pos_src.repeat(negative_ratio).view(-1, negative_ratio)
        assert torch.equal(expected_src, neg_src)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment