Unverified Commit 1ff5f09f authored by peizhou001, committed by GitHub

[Graphbolt] Remove link prediction format (#6298)

parent 09c8e8d9
@@ -7,7 +7,6 @@ import torch
 from .._ffi import libinfo
 from .base import *
 from .minibatch import *
-from .data_format import *
 from .dataloader import *
 from .dataset import *
 from .feature_fetcher import *
...
"""Data format enums for graphbolt."""
from enum import Enum
__all__ = ["LinkPredictionEdgeFormat"]
class LinkPredictionEdgeFormat(Enum):
"""
An Enum class representing the formats of positive and negative edges used
in link prediction:
Attributes:
INDEPENDENT: Represents the 'independent' format where data is structured
as triples `(u, v, label)` indicating the source and destination nodes of
an edge, with a label (0 or 1) denoting it as negative or positive.
CONDITIONED: Represents the 'conditioned' format where data is structured
as quadruples `(u, v, neg_u, neg_v)` indicating the source and destination
nodes of positive and negative edges. And 'u' with 'v' are 1D tensors with
the same shape, while 'neg_u' and 'neg_v' are 2D tensors with the same
shape.
HEAD_CONDITIONED: Represents the 'head conditioned' format where data is
structured as triples `(u, v, neg_u)`, where '(u, v)' signifies the
source and destination nodes of positive edges, while each node in
'neg_u' collaborates with 'v' to create negative edges. And 'u' and 'v' are
1D tensors with the same shape, while 'neg_u' is a 2D tensor.
TAIL_CONDITIONED: Represents the 'tail conditioned' format where data is
structured as triples `(u, v, neg_v)`, where '(u, v)' signifies the
source and destination nodes of positive edges, while 'u' collaborates
with each node in 'neg_v' to create negative edges. And 'u' and 'v' are
1D tensors with the same shape, while 'neg_v' is a 2D tensor.
"""
INDEPENDENT = "independent"
CONDITIONED = "conditioned"
HEAD_CONDITIONED = "head_conditioned"
TAIL_CONDITIONED = "tail_conditioned"
@@ -21,7 +21,6 @@ class UniformNegativeSampler(NegativeSampler):
         self,
         datapipe,
         negative_ratio,
-        output_format,
         graph,
     ):
         """
@@ -33,8 +32,6 @@ class UniformNegativeSampler(NegativeSampler):
             The datapipe.
         negative_ratio : int
             The proportion of negative samples to positive samples.
-        output_format : LinkPredictionEdgeFormat
-            Determines the format of the output data.
         graph : CSCSamplingGraph
             The graph on which to perform negative sampling.
@@ -44,39 +41,21 @@ class UniformNegativeSampler(NegativeSampler):
         >>> indptr = torch.LongTensor([0, 2, 4, 5])
         >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
         >>> graph = gb.from_csc(indptr, indices)
-        >>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
-        >>> node_pairs = torch.tensor([[0, 1], [1, 2]])
+        >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
         >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
         >>> item_sampler = gb.ItemSampler(
         ...     item_set, batch_size=1,
         ... )
         >>> neg_sampler = gb.UniformNegativeSampler(
-        ...     item_sampler, 2, output_format, graph)
-        >>> for data in neg_sampler:
-        ...     print(data.node_pairs, data.negative_dsts)
+        ...     item_sampler, 2, graph)
+        >>> for minibatch in neg_sampler:
+        ...     print(minibatch.negative_srcs)
+        ...     print(minibatch.negative_dsts)
         ...
         (tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
         (tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
-        >>> from dgl import graphbolt as gb
-        >>> indptr = torch.LongTensor([0, 2, 4, 5])
-        >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
-        >>> graph = gb.from_csc(indptr, indices)
-        >>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
-        >>> node_pairs = torch.tensor([[0, 1], [1, 2]])
-        >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
-        >>> item_sampler = gb.ItemSampler(
-        ...     item_set, batch_size=1,
-        ... )
-        >>> neg_sampler = gb.UniformNegativeSampler(
-        ...     item_sampler, 2, output_format, graph)
-        >>> for data in neg_sampler:
-        ...     print(data.node_pairs, data.negative_dsts)
-        ...
-        (tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[0, 1]]))
-        (tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[0, 1]]))
         """
-        super().__init__(datapipe, negative_ratio, output_format)
+        super().__init__(datapipe, negative_ratio)
         self.graph = graph
 
     def _sample_with_etype(self, node_pairs, etype=None):
...
@@ -2,12 +2,9 @@
 from _collections_abc import Mapping
 
-import torch
 from torch.utils.data import functional_datapipe
 from torchdata.datapipes.iter import Mapper
 
-from .data_format import LinkPredictionEdgeFormat
 
 @functional_datapipe("sample_negative")
 class NegativeSampler(Mapper):
@@ -20,7 +17,6 @@ class NegativeSampler(Mapper):
         self,
         datapipe,
         negative_ratio,
-        output_format,
     ):
         """
         Initlization for a negative sampler.
@@ -31,13 +27,10 @@ class NegativeSampler(Mapper):
             The datapipe.
         negative_ratio : int
             The proportion of negative samples to positive samples.
-        output_format : LinkPredictionEdgeFormat
-            Determines the edge format of the output minibatch.
         """
         super().__init__(datapipe, self._sample)
         assert negative_ratio > 0, "Negative_ratio should be positive Integer."
         self.negative_ratio = negative_ratio
-        self.output_format = output_format
 
     def _sample(self, minibatch):
         """
@@ -61,18 +54,11 @@ class NegativeSampler(Mapper):
         node_pairs = minibatch.node_pairs
         assert node_pairs is not None
         if isinstance(node_pairs, Mapping):
-            if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
-                minibatch.labels = {}
-            else:
-                minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
+            minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
             for etype, pos_pairs in node_pairs.items():
                 self._collate(
                     minibatch, self._sample_with_etype(pos_pairs, etype), etype
                 )
-            if self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED:
-                minibatch.negative_dsts = None
-            if self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED:
-                minibatch.negative_srcs = None
         else:
             self._collate(minibatch, self._sample_with_etype(node_pairs))
         return minibatch
@@ -112,45 +98,14 @@ class NegativeSampler(Mapper):
         etype : str
             Canonical edge type.
         """
-        pos_src, pos_dst = (
-            minibatch.node_pairs[etype]
-            if etype is not None
-            else minibatch.node_pairs
-        )
         neg_src, neg_dst = neg_pairs
-        if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
-            pos_labels = torch.ones_like(pos_src)
-            neg_labels = torch.zeros_like(neg_src)
-            src = torch.cat([pos_src, neg_src])
-            dst = torch.cat([pos_dst, neg_dst])
-            labels = torch.cat([pos_labels, neg_labels])
-            if etype is not None:
-                minibatch.node_pairs[etype] = (src, dst)
-                minibatch.labels[etype] = labels
-            else:
-                minibatch.node_pairs = (src, dst)
-                minibatch.labels = labels
-        else:
-            if self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
-                neg_src = neg_src.view(-1, self.negative_ratio)
-                neg_dst = neg_dst.view(-1, self.negative_ratio)
-            elif (
-                self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED
-            ):
-                neg_src = neg_src.view(-1, self.negative_ratio)
-                neg_dst = None
-            elif (
-                self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED
-            ):
-                neg_dst = neg_dst.view(-1, self.negative_ratio)
-                neg_src = None
-            else:
-                raise TypeError(
-                    f"Unsupported output format {self.output_format}."
-                )
-            if etype is not None:
-                minibatch.negative_srcs[etype] = neg_src
-                minibatch.negative_dsts[etype] = neg_dst
-            else:
-                minibatch.negative_srcs = neg_src
-                minibatch.negative_dsts = neg_dst
+        if neg_src is not None:
+            neg_src = neg_src.view(-1, self.negative_ratio)
+        if neg_dst is not None:
+            neg_dst = neg_dst.view(-1, self.negative_ratio)
+        if etype is not None:
+            minibatch.negative_srcs[etype] = neg_src
+            minibatch.negative_dsts[etype] = neg_dst
+        else:
+            minibatch.negative_srcs = neg_src
+            minibatch.negative_dsts = neg_dst
@@ -57,6 +57,12 @@ class SubgraphSampler(Mapper):
         has_neg_dst = neg_dst is not None
         is_heterogeneous = isinstance(node_pairs, Dict)
         if is_heterogeneous:
+            has_neg_src = has_neg_src and all(
+                item is not None for item in neg_src.values()
+            )
+            has_neg_dst = has_neg_dst and all(
+                item is not None for item in neg_dst.values()
+            )
             # Collect nodes from all types of input.
             nodes = defaultdict(list)
             for etype, (src, dst) in node_pairs.items():
...
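The added guard only treats negatives as present when every edge type actually carries a tensor. A quick hedged illustration (the edge-type keys below are made up for the example):

import torch

neg_src = {
    "user:follow:user": torch.tensor([[3, 7]]),
    "user:click:item": None,  # this edge type has no negative sources
}
has_neg_src = neg_src is not None and all(
    item is not None for item in neg_src.values()
)
print(has_neg_src)  # False, so negative-source handling is skipped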
@@ -2,7 +2,6 @@ import dgl.graphbolt as gb
 import gb_test_utils
 import pytest
 import torch
-from torchdata.datapipes.iter import Mapper
 
 
 def test_NegativeSampler_invoke():
@@ -19,7 +18,6 @@ def test_NegativeSampler_invoke():
     negative_sampler = gb.NegativeSampler(
         item_sampler,
         negative_ratio,
-        gb.LinkPredictionEdgeFormat.INDEPENDENT,
     )
     with pytest.raises(NotImplementedError):
         next(iter(negative_sampler))
@@ -27,7 +25,6 @@ def test_NegativeSampler_invoke():
     # Invoke NegativeSampler via functional form.
     negative_sampler = item_sampler.sample_negative(
         negative_ratio,
-        gb.LinkPredictionEdgeFormat.INDEPENDENT,
     )
     with pytest.raises(NotImplementedError):
         next(iter(negative_sampler))
@@ -47,20 +44,16 @@ def test_UniformNegativeSampler_invoke():
     # Verify iteration over UniformNegativeSampler.
     def _verify(negative_sampler):
         for data in negative_sampler:
-            src, dst = data.node_pairs
-            labels = data.labels
             # Assertation
-            assert len(src) == batch_size * (negative_ratio + 1)
-            assert len(dst) == batch_size * (negative_ratio + 1)
-            assert len(labels) == batch_size * (negative_ratio + 1)
-            assert torch.all(torch.eq(labels[:batch_size], 1))
-            assert torch.all(torch.eq(labels[batch_size:], 0))
+            assert data.negative_srcs.size(0) == batch_size
+            assert data.negative_srcs.size(1) == negative_ratio
+            assert data.negative_dsts.size(0) == batch_size
+            assert data.negative_dsts.size(1) == negative_ratio
 
     # Invoke UniformNegativeSampler via class constructor.
     negative_sampler = gb.UniformNegativeSampler(
         item_sampler,
         negative_ratio,
-        gb.LinkPredictionEdgeFormat.INDEPENDENT,
         graph,
     )
     _verify(negative_sampler)
@@ -68,14 +61,13 @@ def test_UniformNegativeSampler_invoke():
     # Invoke UniformNegativeSampler via functional form.
     negative_sampler = item_sampler.sample_uniform_negative(
         negative_ratio,
-        gb.LinkPredictionEdgeFormat.INDEPENDENT,
         graph,
     )
     _verify(negative_sampler)
 
 
 @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
-def test_NegativeSampler_Independent_Format(negative_ratio):
+def test_Uniform_NegativeSampler(negative_ratio):
     # Construct CSCSamplingGraph.
     graph = gb_test_utils.rand_csc_graph(100, 0.05)
     num_seeds = 30
@@ -88,36 +80,6 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
     negative_sampler = gb.UniformNegativeSampler(
         item_sampler,
         negative_ratio,
-        gb.LinkPredictionEdgeFormat.INDEPENDENT,
-        graph,
-    )
-    # Perform Negative sampling.
-    for data in negative_sampler:
-        src, dst = data.node_pairs
-        labels = data.labels
-        # Assertation
-        assert len(src) == batch_size * (negative_ratio + 1)
-        assert len(dst) == batch_size * (negative_ratio + 1)
-        assert len(labels) == batch_size * (negative_ratio + 1)
-        assert torch.all(torch.eq(labels[:batch_size], 1))
-        assert torch.all(torch.eq(labels[batch_size:], 0))
-
-
-@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
-def test_NegativeSampler_Conditioned_Format(negative_ratio):
-    # Construct CSCSamplingGraph.
-    graph = gb_test_utils.rand_csc_graph(100, 0.05)
-    num_seeds = 30
-    item_set = gb.ItemSet(
-        torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
-    )
-    batch_size = 10
-    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
-    # Construct NegativeSampler.
-    negative_sampler = gb.UniformNegativeSampler(
-        item_sampler,
-        negative_ratio,
-        gb.LinkPredictionEdgeFormat.CONDITIONED,
         graph,
     )
     # Perform Negative sampling.
@@ -135,64 +97,6 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
         assert torch.equal(expected_src, neg_src)
-
-
-@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
-def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
-    # Construct CSCSamplingGraph.
-    graph = gb_test_utils.rand_csc_graph(100, 0.05)
-    num_seeds = 30
-    item_set = gb.ItemSet(
-        torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
-    )
-    batch_size = 10
-    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
-    # Construct NegativeSampler.
-    negative_sampler = gb.UniformNegativeSampler(
-        item_sampler,
-        negative_ratio,
-        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
-        graph,
-    )
-    # Perform Negative sampling.
-    for data in negative_sampler:
-        pos_src, pos_dst = data.node_pairs
-        neg_src = data.negative_srcs
-        # Assertation
-        assert len(pos_src) == batch_size
-        assert len(pos_dst) == batch_size
-        assert len(neg_src) == batch_size
-        assert neg_src.numel() == batch_size * negative_ratio
-        expected_src = pos_src.repeat(negative_ratio).view(-1, negative_ratio)
-        assert torch.equal(expected_src, neg_src)
-
-
-@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
-def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
-    # Construct CSCSamplingGraph.
-    graph = gb_test_utils.rand_csc_graph(100, 0.05)
-    num_seeds = 30
-    item_set = gb.ItemSet(
-        torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
-    )
-    batch_size = 10
-    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
-    # Construct NegativeSampler.
-    negative_sampler = gb.UniformNegativeSampler(
-        item_sampler,
-        negative_ratio,
-        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
-        graph,
-    )
-    # Perform Negative sampling.
-    for data in negative_sampler:
-        pos_src, pos_dst = data.node_pairs
-        neg_dst = data.negative_dsts
-        # Assertation
-        assert len(pos_src) == batch_size
-        assert len(pos_dst) == batch_size
-        assert len(neg_dst) == batch_size
-        assert neg_dst.numel() == batch_size * negative_ratio
 
 
 def get_hetero_graph():
     # COO graph:
     # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
@@ -215,16 +119,7 @@ def get_hetero_graph():
     )
 
 
-@pytest.mark.parametrize(
-    "format",
-    [
-        gb.LinkPredictionEdgeFormat.INDEPENDENT,
-        gb.LinkPredictionEdgeFormat.CONDITIONED,
-        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
-        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
-    ],
-)
-def test_NegativeSampler_Hetero_Data(format):
+def test_NegativeSampler_Hetero_Data():
     graph = get_hetero_graph()
     itemset = gb.ItemSetDict(
         {
@@ -240,5 +135,5 @@ def test_NegativeSampler_Hetero_Data(format):
     )
     item_sampler = gb.ItemSampler(itemset, batch_size=2)
-    negative_dp = gb.UniformNegativeSampler(item_sampler, 1, format, graph)
+    negative_dp = gb.UniformNegativeSampler(item_sampler, 1, graph)
     assert len(list(negative_dp)) == 5
@@ -2,7 +2,6 @@ import dgl.graphbolt as gb
 import gb_test_utils
 import pytest
 import torch
-import torchdata.datapipes as dp
 from torchdata.datapipes.iter import Mapper
@@ -71,23 +70,14 @@ def test_SubgraphSampler_Link(labor):
     assert len(list(neighbor_dp)) == 5
 
 
-@pytest.mark.parametrize(
-    "format",
-    [
-        gb.LinkPredictionEdgeFormat.INDEPENDENT,
-        gb.LinkPredictionEdgeFormat.CONDITIONED,
-        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
-        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
-    ],
-)
 @pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_Link_With_Negative(format, labor):
+def test_SubgraphSampler_Link_With_Negative(labor):
     graph = gb_test_utils.rand_csc_graph(20, 0.15)
     itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
     item_sampler = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    negative_dp = gb.UniformNegativeSampler(item_sampler, 1, format, graph)
+    negative_dp = gb.UniformNegativeSampler(item_sampler, 1, graph)
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
     neighbor_dp = Sampler(negative_dp, graph, fanouts)
     assert len(list(neighbor_dp)) == 5
@@ -139,17 +129,8 @@ def test_SubgraphSampler_Link_Hetero(labor):
     assert len(list(neighbor_dp)) == 5
 
 
-@pytest.mark.parametrize(
-    "format",
-    [
-        gb.LinkPredictionEdgeFormat.INDEPENDENT,
-        gb.LinkPredictionEdgeFormat.CONDITIONED,
-        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
-        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
-    ],
-)
 @pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
+def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
     graph = get_hetero_graph()
     itemset = gb.ItemSetDict(
         {
@@ -167,7 +148,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
     item_sampler = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    negative_dp = gb.UniformNegativeSampler(item_sampler, 1, format, graph)
+    negative_dp = gb.UniformNegativeSampler(item_sampler, 1, graph)
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
     neighbor_dp = Sampler(negative_dp, graph, fanouts)
     assert len(list(neighbor_dp)) == 5