Unverified Commit 79a95477 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] clean up minibatch collator in testcases and docstring (#6294)

parent b4c351b4
......@@ -53,24 +53,18 @@ class NeighborSampler(SubgraphSampler):
-------
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_converter, 2, data_format, graph)
...item_sampler, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])]
>>> subgraph_sampler = gb.NeighborSampler(
......@@ -165,24 +159,18 @@ class LayerNeighborSampler(NeighborSampler):
-------
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_converter, 2, data_format, graph)
...item_sampler, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])]
>>> subgraph_sampler = gb.LayerNeighborSampler(
......
......@@ -42,15 +42,15 @@ class UniformNegativeSampler(NegativeSampler):
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
>>> node_pairs = torch.tensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data)
... print(data.node_pairs, data.negative_dsts)
...
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
(tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
......@@ -60,18 +60,18 @@ class UniformNegativeSampler(NegativeSampler):
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
>>> node_pairs = torch.tensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data)
... print(data.node_pairs, data.negative_dsts)
...
(tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[2, 1]]))
(tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[1, 2]]))
(tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[0, 1]]))
(tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[0, 1]]))
"""
super().__init__(datapipe, negative_ratio, output_format)
self.graph = graph
......
......@@ -8,16 +8,6 @@ import scipy.sparse as sp
import torch
def minibatch_node_collator(data):
minibatch = gb.MiniBatch(seed_nodes=data)
return minibatch
def minibatch_link_collator(data):
minibatch = gb.MiniBatch(node_pairs=data)
return minibatch
def rand_csc_graph(N, density):
adj = sp.random(N, N, density)
adj = adj + adj.T
......
......@@ -11,19 +11,13 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30
item_set = gb.ItemSet(
(
torch.arange(0, num_seeds),
torch.arange(num_seeds, num_seeds * 2),
)
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
minibatch_converter,
item_sampler,
negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph,
......@@ -46,19 +40,13 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30
item_set = gb.ItemSet(
(
torch.arange(0, num_seeds),
torch.arange(num_seeds, num_seeds * 2),
)
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
minibatch_converter,
item_sampler,
negative_ratio,
gb.LinkPredictionEdgeFormat.CONDITIONED,
graph,
......@@ -84,19 +72,13 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30
item_set = gb.ItemSet(
(
torch.arange(0, num_seeds),
torch.arange(num_seeds, num_seeds * 2),
)
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
minibatch_converter,
item_sampler,
negative_ratio,
gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
graph,
......@@ -120,19 +102,13 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30
item_set = gb.ItemSet(
(
torch.arange(0, num_seeds),
torch.arange(num_seeds, num_seeds * 2),
)
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
minibatch_converter,
item_sampler,
negative_ratio,
gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
graph,
......@@ -184,25 +160,18 @@ def test_NegativeSampler_Hetero_Data(format):
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 0, 1]),
)
torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
names="node_pairs",
),
"n2:e2:n1": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1, 2, 2]),
torch.LongTensor([0, 1, 1, 0, 0, 1]),
)
torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
names="node_pairs",
),
}
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler(
minibatch_converter, 1, format, graph
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
negative_dp = gb.UniformNegativeSampler(item_sampler, 1, format, graph)
for neg in negative_dp:
print(neg)
assert len(list(negative_dp)) == 5
......@@ -15,14 +15,11 @@ def test_FeatureFetcher_homo():
features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10))
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
sampler_dp = gb.NeighborSampler(minibatch_converter, graph, fanouts)
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5
......@@ -99,17 +96,14 @@ def test_FeatureFetcher_hetero():
itemset = gb.ItemSetDict(
{
"n1": gb.ItemSet(torch.LongTensor([0, 1])),
"n2": gb.ItemSet(torch.LongTensor([0, 1, 2])),
"n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seed_nodes"),
"n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seed_nodes"),
}
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
sampler_dp = gb.NeighborSampler(minibatch_converter, graph, fanouts)
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
)
......
......@@ -13,7 +13,7 @@ from torchdata.datapipes.iter import Mapper
def test_DataLoader():
N = 40
B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N))
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15)
features = {}
keys = [("node", None, "a"), ("node", None, "b")]
......@@ -22,11 +22,8 @@ def test_DataLoader():
feature_store = dgl.graphbolt.BasicFeatureStore(features)
item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_node_collator
)
subgraph_sampler = dgl.graphbolt.NeighborSampler(
minibatch_converter,
item_sampler,
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
......
......@@ -10,7 +10,7 @@ from torchdata.datapipes.iter import Mapper
def test_DataLoader():
N = 32
B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N))
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15)
features = {}
......@@ -20,11 +20,8 @@ def test_DataLoader():
feature_store = dgl.graphbolt.BasicFeatureStore(features)
item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_node_collator
)
subgraph_sampler = dgl.graphbolt.NeighborSampler(
minibatch_converter,
item_sampler,
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
......
......@@ -9,15 +9,12 @@ from torchdata.datapipes.iter import Mapper
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(torch.arange(10))
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
sampler_dp = Sampler(minibatch_converter, graph, fanouts)
sampler_dp = Sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 5
......@@ -29,20 +26,12 @@ def to_link_batch(data):
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(
(
torch.arange(0, 10),
torch.arange(10, 20),
)
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(minibatch_converter, graph, fanouts)
neighbor_dp = Sampler(item_sampler, graph, fanouts)
assert len(list(neighbor_dp)) == 5
......@@ -58,21 +47,11 @@ def test_SubgraphSampler_Link(labor):
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(format, labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(
(
torch.arange(0, 10),
torch.arange(10, 20),
)
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler(
minibatch_converter, 1, format, graph
)
negative_dp = gb.UniformNegativeSampler(item_sampler, 1, format, graph)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5
......@@ -106,28 +85,21 @@ def test_SubgraphSampler_Link_Hetero(labor):
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 0, 1]),
)
torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
names="node_pairs",
),
"n2:e2:n1": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1, 2, 2]),
torch.LongTensor([0, 1, 1, 0, 0, 1]),
)
torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
names="node_pairs",
),
}
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(minibatch_converter, graph, fanouts)
neighbor_dp = Sampler(item_sampler, graph, fanouts)
assert len(list(neighbor_dp)) == 5
......@@ -146,29 +118,20 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 0, 1]),
)
torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
names="node_pairs",
),
"n2:e2:n1": gb.ItemSet(
(
torch.LongTensor([0, 0, 1, 1, 2, 2]),
torch.LongTensor([0, 1, 1, 0, 0, 1]),
)
torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
names="node_pairs",
),
}
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler(
minibatch_converter, 1, format, graph
)
negative_dp = gb.UniformNegativeSampler(item_sampler, 1, format, graph)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment