Unverified commit c6cdeb6b, authored by peizhou001 and committed by GitHub
Browse files

[Graphbolt] Adapt negative sampler input (#6177)

parent 405de769
......@@ -6,7 +6,6 @@ import torch
from torchdata.datapipes.iter import Mapper
from .data_format import LinkPredictionEdgeFormat
from .link_prediction_block import LinkPredictionBlock
class NegativeSampler(Mapper):
......@@ -38,16 +37,18 @@ class NegativeSampler(Mapper):
self.negative_ratio = negative_ratio
self.output_format = output_format
def _sample(self, node_pairs):
def _sample(self, data):
"""
Generate a mix of positive and negative samples.
Parameters
----------
node_pairs : Tuple[Tensor] or Dict[etype, Tuple[Tensor]]
A tuple of tensors or a dictionary represents source-destination
node pairs of positive edges, where positive means the edge must
exist in the graph.
data : LinkPredictionBlock
An instance of 'LinkPredictionBlock' class requires the 'node_pair'
field. This function is responsible for generating negative edges
corresponding to the positive edges defined by the 'node_pair'. In
cases where negative edges already exist, this function will
overwrite them.
Returns
-------
......@@ -55,8 +56,7 @@ class NegativeSampler(Mapper):
An instance of 'LinkPredictionBlock' encompasses both positive and
negative samples.
"""
data = LinkPredictionBlock(node_pair=node_pairs)
node_pairs = data.node_pair
if isinstance(node_pairs, Mapping):
for etype, pos_pairs in node_pairs.items():
self._collate(
......
......@@ -2,6 +2,11 @@ import dgl.graphbolt as gb
import gb_test_utils
import pytest
import torch
from torchdata.datapipes.iter import Mapper
def to_data_block(data):
    """Wrap ``data`` in a ``gb.LinkPredictionBlock``.

    Parameters
    ----------
    data : object
        Value passed unchanged as the ``node_pair`` argument of
        ``gb.LinkPredictionBlock``.

    Returns
    -------
    gb.LinkPredictionBlock
        A block constructed with ``node_pair=data``.
    """
    return gb.LinkPredictionBlock(node_pair=data)
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
......@@ -17,9 +22,10 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
batch_size = 10
minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(minibatch_sampler, to_data_block)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler,
data_block_converter,
negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph,
......@@ -49,9 +55,10 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
batch_size = 10
minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(minibatch_sampler, to_data_block)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler,
data_block_converter,
negative_ratio,
gb.LinkPredictionEdgeFormat.CONDITIONED,
graph,
......@@ -84,9 +91,10 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
batch_size = 10
minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(minibatch_sampler, to_data_block)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler,
data_block_converter,
negative_ratio,
gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
graph,
......@@ -117,9 +125,10 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
batch_size = 10
minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(minibatch_sampler, to_data_block)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler,
data_block_converter,
negative_ratio,
gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
graph,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment