"vscode:/vscode.git/clone" did not exist on "85b7858075d214f27c8e981cd9fea28a8f95249c"
Unverified Commit c6cdeb6b authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Adapt negative sampler input (#6177)

parent 405de769
...@@ -6,7 +6,6 @@ import torch ...@@ -6,7 +6,6 @@ import torch
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
from .data_format import LinkPredictionEdgeFormat from .data_format import LinkPredictionEdgeFormat
from .link_prediction_block import LinkPredictionBlock
class NegativeSampler(Mapper): class NegativeSampler(Mapper):
...@@ -38,16 +37,18 @@ class NegativeSampler(Mapper): ...@@ -38,16 +37,18 @@ class NegativeSampler(Mapper):
self.negative_ratio = negative_ratio self.negative_ratio = negative_ratio
self.output_format = output_format self.output_format = output_format
def _sample(self, node_pairs): def _sample(self, data):
""" """
Generate a mix of positive and negative samples. Generate a mix of positive and negative samples.
Parameters Parameters
---------- ----------
node_pairs : Tuple[Tensor] or Dict[etype, Tuple[Tensor]] data : LinkPredictionBlock
A tuple of tensors or a dictionary represents source-destination An instance of 'LinkPredictionBlock' class requires the 'node_pair'
node pairs of positive edges, where positive means the edge must field. This function is responsible for generating negative edges
exist in the graph. corresponding to the positive edges defined by the 'node_pair'. In
cases where negative edges already exist, this function will
overwrite them.
Returns Returns
------- -------
...@@ -55,8 +56,7 @@ class NegativeSampler(Mapper): ...@@ -55,8 +56,7 @@ class NegativeSampler(Mapper):
An instance of 'LinkPredictionBlock' encompasses both positive and An instance of 'LinkPredictionBlock' encompasses both positive and
negative samples. negative samples.
""" """
node_pairs = data.node_pair
data = LinkPredictionBlock(node_pair=node_pairs)
if isinstance(node_pairs, Mapping): if isinstance(node_pairs, Mapping):
for etype, pos_pairs in node_pairs.items(): for etype, pos_pairs in node_pairs.items():
self._collate( self._collate(
......
...@@ -2,6 +2,11 @@ import dgl.graphbolt as gb ...@@ -2,6 +2,11 @@ import dgl.graphbolt as gb
import gb_test_utils import gb_test_utils
import pytest import pytest
import torch import torch
from torchdata.datapipes.iter import Mapper
def to_data_block(data):
return gb.LinkPredictionBlock(node_pair=data)
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
...@@ -17,9 +22,10 @@ def test_NegativeSampler_Independent_Format(negative_ratio): ...@@ -17,9 +22,10 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
) )
batch_size = 10 batch_size = 10
minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size) minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(minibatch_sampler, to_data_block)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler, data_block_converter,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT, gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph, graph,
...@@ -49,9 +55,10 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio): ...@@ -49,9 +55,10 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
) )
batch_size = 10 batch_size = 10
minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size) minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(minibatch_sampler, to_data_block)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler, data_block_converter,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.CONDITIONED, gb.LinkPredictionEdgeFormat.CONDITIONED,
graph, graph,
...@@ -84,9 +91,10 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio): ...@@ -84,9 +91,10 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
) )
batch_size = 10 batch_size = 10
minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size) minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(minibatch_sampler, to_data_block)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler, data_block_converter,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED, gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
graph, graph,
...@@ -117,9 +125,10 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio): ...@@ -117,9 +125,10 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
) )
batch_size = 10 batch_size = 10
minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size) minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(minibatch_sampler, to_data_block)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_sampler, data_block_converter,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED, gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
graph, graph,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment