Unverified commit 79a95477, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] clean up minibatch collator in testcases and docstring (#6294)

parent b4c351b4
...@@ -53,24 +53,18 @@ class NeighborSampler(SubgraphSampler): ...@@ -53,24 +53,18 @@ class NeighborSampler(SubgraphSampler):
------- -------
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper >>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch
...
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_csc(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2])) >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs) >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_converter, 2, data_format, graph) ...item_sampler, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]), >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])] ...torch.LongTensor([15])]
>>> subgraph_sampler = gb.NeighborSampler( >>> subgraph_sampler = gb.NeighborSampler(
...@@ -165,24 +159,18 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -165,24 +159,18 @@ class LayerNeighborSampler(NeighborSampler):
------- -------
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper >>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch
...
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_csc(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2])) >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs) >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_converter, 2, data_format, graph) ...item_sampler, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]), >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])] ...torch.LongTensor([15])]
>>> subgraph_sampler = gb.LayerNeighborSampler( >>> subgraph_sampler = gb.LayerNeighborSampler(
......
...@@ -42,15 +42,15 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -42,15 +42,15 @@ class UniformNegativeSampler(NegativeSampler):
>>> indices = torch.LongTensor([1, 2, 0, 2, 0]) >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_csc(indptr, indices)
>>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT >>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2])) >>> node_pairs = torch.tensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs) >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, output_format, graph) ...item_sampler, 2, output_format, graph)
>>> for data in neg_sampler: >>> for data in neg_sampler:
... print(data) ... print(data.node_pairs, data.negative_dsts)
... ...
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0])) (tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
(tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0])) (tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
...@@ -60,18 +60,18 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -60,18 +60,18 @@ class UniformNegativeSampler(NegativeSampler):
>>> indices = torch.LongTensor([1, 2, 0, 2, 0]) >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_csc(indptr, indices)
>>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED >>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2])) >>> node_pairs = torch.tensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs) >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, output_format, graph) ...item_sampler, 2, output_format, graph)
>>> for data in neg_sampler: >>> for data in neg_sampler:
... print(data) ... print(data.node_pairs, data.negative_dsts)
... ...
(tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[2, 1]])) (tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[0, 1]]))
(tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[1, 2]])) (tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[0, 1]]))
""" """
super().__init__(datapipe, negative_ratio, output_format) super().__init__(datapipe, negative_ratio, output_format)
self.graph = graph self.graph = graph
......
...@@ -8,16 +8,6 @@ import scipy.sparse as sp ...@@ -8,16 +8,6 @@ import scipy.sparse as sp
import torch import torch
def minibatch_node_collator(data):
minibatch = gb.MiniBatch(seed_nodes=data)
return minibatch
def minibatch_link_collator(data):
minibatch = gb.MiniBatch(node_pairs=data)
return minibatch
def rand_csc_graph(N, density): def rand_csc_graph(N, density):
adj = sp.random(N, N, density) adj = sp.random(N, N, density)
adj = adj + adj.T adj = adj + adj.T
......
...@@ -11,19 +11,13 @@ def test_NegativeSampler_Independent_Format(negative_ratio): ...@@ -11,19 +11,13 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
graph = gb_test_utils.rand_csc_graph(100, 0.05) graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
( torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
torch.arange(0, num_seeds),
torch.arange(num_seeds, num_seeds * 2),
)
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_converter, item_sampler,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT, gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph, graph,
...@@ -46,19 +40,13 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio): ...@@ -46,19 +40,13 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
graph = gb_test_utils.rand_csc_graph(100, 0.05) graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
( torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
torch.arange(0, num_seeds),
torch.arange(num_seeds, num_seeds * 2),
)
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_converter, item_sampler,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.CONDITIONED, gb.LinkPredictionEdgeFormat.CONDITIONED,
graph, graph,
...@@ -84,19 +72,13 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio): ...@@ -84,19 +72,13 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
graph = gb_test_utils.rand_csc_graph(100, 0.05) graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
( torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
torch.arange(0, num_seeds),
torch.arange(num_seeds, num_seeds * 2),
)
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_converter, item_sampler,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED, gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
graph, graph,
...@@ -120,19 +102,13 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio): ...@@ -120,19 +102,13 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
graph = gb_test_utils.rand_csc_graph(100, 0.05) graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
( torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
torch.arange(0, num_seeds),
torch.arange(num_seeds, num_seeds * 2),
)
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
minibatch_converter, item_sampler,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED, gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
graph, graph,
...@@ -184,25 +160,18 @@ def test_NegativeSampler_Hetero_Data(format): ...@@ -184,25 +160,18 @@ def test_NegativeSampler_Hetero_Data(format):
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1:e1:n2": gb.ItemSet( "n1:e1:n2": gb.ItemSet(
( torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
torch.LongTensor([0, 0, 1, 1]), names="node_pairs",
torch.LongTensor([0, 2, 0, 1]),
)
), ),
"n2:e2:n1": gb.ItemSet( "n2:e2:n1": gb.ItemSet(
( torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
torch.LongTensor([0, 0, 1, 1, 2, 2]), names="node_pairs",
torch.LongTensor([0, 1, 1, 0, 0, 1]),
)
), ),
} }
) )
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
minibatch_converter = Mapper( negative_dp = gb.UniformNegativeSampler(item_sampler, 1, format, graph)
item_sampler_dp, gb_test_utils.minibatch_link_collator for neg in negative_dp:
) print(neg)
negative_dp = gb.UniformNegativeSampler(
minibatch_converter, 1, format, graph
)
assert len(list(negative_dp)) == 5 assert len(list(negative_dp)) == 5
...@@ -15,14 +15,11 @@ def test_FeatureFetcher_homo(): ...@@ -15,14 +15,11 @@ def test_FeatureFetcher_homo():
features[keys[1]] = gb.TorchBasedFeature(b) features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features) feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10)) itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper( sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
sampler_dp = gb.NeighborSampler(minibatch_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"]) fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5 assert len(list(fetcher_dp)) == 5
...@@ -99,17 +96,14 @@ def test_FeatureFetcher_hetero(): ...@@ -99,17 +96,14 @@ def test_FeatureFetcher_hetero():
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1": gb.ItemSet(torch.LongTensor([0, 1])), "n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seed_nodes"),
"n2": gb.ItemSet(torch.LongTensor([0, 1, 2])), "n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seed_nodes"),
} }
) )
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper( sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
sampler_dp = gb.NeighborSampler(minibatch_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher( fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]} sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
) )
......
...@@ -13,7 +13,7 @@ from torchdata.datapipes.iter import Mapper ...@@ -13,7 +13,7 @@ from torchdata.datapipes.iter import Mapper
def test_DataLoader(): def test_DataLoader():
N = 40 N = 40
B = 4 B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N)) itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15) graph = gb_test_utils.rand_csc_graph(200, 0.15)
features = {} features = {}
keys = [("node", None, "a"), ("node", None, "b")] keys = [("node", None, "a"), ("node", None, "b")]
...@@ -22,11 +22,8 @@ def test_DataLoader(): ...@@ -22,11 +22,8 @@ def test_DataLoader():
feature_store = dgl.graphbolt.BasicFeatureStore(features) feature_store = dgl.graphbolt.BasicFeatureStore(features)
item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B) item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_node_collator
)
subgraph_sampler = dgl.graphbolt.NeighborSampler( subgraph_sampler = dgl.graphbolt.NeighborSampler(
minibatch_converter, item_sampler,
graph, graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)], fanouts=[torch.LongTensor([2]) for _ in range(2)],
) )
......
...@@ -10,7 +10,7 @@ from torchdata.datapipes.iter import Mapper ...@@ -10,7 +10,7 @@ from torchdata.datapipes.iter import Mapper
def test_DataLoader(): def test_DataLoader():
N = 32 N = 32
B = 4 B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N)) itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15) graph = gb_test_utils.rand_csc_graph(200, 0.15)
features = {} features = {}
...@@ -20,11 +20,8 @@ def test_DataLoader(): ...@@ -20,11 +20,8 @@ def test_DataLoader():
feature_store = dgl.graphbolt.BasicFeatureStore(features) feature_store = dgl.graphbolt.BasicFeatureStore(features)
item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B) item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_node_collator
)
subgraph_sampler = dgl.graphbolt.NeighborSampler( subgraph_sampler = dgl.graphbolt.NeighborSampler(
minibatch_converter, item_sampler,
graph, graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)], fanouts=[torch.LongTensor([2]) for _ in range(2)],
) )
......
...@@ -9,15 +9,12 @@ from torchdata.datapipes.iter import Mapper ...@@ -9,15 +9,12 @@ from torchdata.datapipes.iter import Mapper
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor): def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(torch.arange(10)) itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
sampler_dp = Sampler(minibatch_converter, graph, fanouts) sampler_dp = Sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 5 assert len(list(sampler_dp)) == 5
...@@ -29,20 +26,12 @@ def to_link_batch(data): ...@@ -29,20 +26,12 @@ def to_link_batch(data):
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor): def test_SubgraphSampler_Link(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet( itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
( item_sampler = gb.ItemSampler(itemset, batch_size=2)
torch.arange(0, 10),
torch.arange(10, 20),
)
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(minibatch_converter, graph, fanouts) neighbor_dp = Sampler(item_sampler, graph, fanouts)
assert len(list(neighbor_dp)) == 5 assert len(list(neighbor_dp)) == 5
...@@ -58,21 +47,11 @@ def test_SubgraphSampler_Link(labor): ...@@ -58,21 +47,11 @@ def test_SubgraphSampler_Link(labor):
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(format, labor): def test_SubgraphSampler_Link_With_Negative(format, labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet( itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
( item_sampler = gb.ItemSampler(itemset, batch_size=2)
torch.arange(0, 10),
torch.arange(10, 20),
)
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper( negative_dp = gb.UniformNegativeSampler(item_sampler, 1, format, graph)
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler(
minibatch_converter, 1, format, graph
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts) neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5 assert len(list(neighbor_dp)) == 5
...@@ -106,28 +85,21 @@ def test_SubgraphSampler_Link_Hetero(labor): ...@@ -106,28 +85,21 @@ def test_SubgraphSampler_Link_Hetero(labor):
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1:e1:n2": gb.ItemSet( "n1:e1:n2": gb.ItemSet(
( torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
torch.LongTensor([0, 0, 1, 1]), names="node_pairs",
torch.LongTensor([0, 2, 0, 1]),
)
), ),
"n2:e2:n1": gb.ItemSet( "n2:e2:n1": gb.ItemSet(
( torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
torch.LongTensor([0, 0, 1, 1, 2, 2]), names="node_pairs",
torch.LongTensor([0, 1, 1, 0, 0, 1]),
)
), ),
} }
) )
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(minibatch_converter, graph, fanouts) neighbor_dp = Sampler(item_sampler, graph, fanouts)
assert len(list(neighbor_dp)) == 5 assert len(list(neighbor_dp)) == 5
...@@ -146,29 +118,20 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor): ...@@ -146,29 +118,20 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1:e1:n2": gb.ItemSet( "n1:e1:n2": gb.ItemSet(
( torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
torch.LongTensor([0, 0, 1, 1]), names="node_pairs",
torch.LongTensor([0, 2, 0, 1]),
)
), ),
"n2:e2:n1": gb.ItemSet( "n2:e2:n1": gb.ItemSet(
( torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
torch.LongTensor([0, 0, 1, 1, 2, 2]), names="node_pairs",
torch.LongTensor([0, 1, 1, 0, 0, 1]),
)
), ),
} }
) )
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper( negative_dp = gb.UniformNegativeSampler(item_sampler, 1, format, graph)
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler(
minibatch_converter, 1, format, graph
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts) neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5 assert len(list(neighbor_dp)) == 5
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment