Unverified Commit 240e28a2 authored by Rhett Ying, committed by GitHub

[GraphBolt] rename MinibatchSampler as ItemSampler (#6255)

parent fc366945
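
The change is mechanical: `MinibatchSampler` becomes `ItemSampler` across the
package, its module moves from `minibatch_sampler.py` to `item_sampler.py`
(see the import diff below), and docstrings are reworded; constructor
arguments are untouched. A minimal sketch of the user-facing rename, using
the constructor shown in the docstrings below:

    import torch
    from dgl import graphbolt as gb

    item_set = gb.ItemSet(torch.arange(0, 10))
    # Before this commit: sampler = gb.MinibatchSampler(item_set, batch_size=4)
    sampler = gb.ItemSampler(item_set, batch_size=4)
    print(list(sampler))  # three batches of node IDs: 4 + 4 + 2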
@@ -14,7 +14,7 @@ from .feature_fetcher import *
from .feature_store import *
from .impl import *
from .itemset import *
-from .minibatch_sampler import *
+from .item_sampler import *
from .negative_sampler import *
from .sampled_subgraph import *
from .subgraph_sampler import *
@@ -5,7 +5,7 @@ import torchdata.dataloader2.graph as dp_utils
import torchdata.datapipes as dp
from .feature_fetcher import FeatureFetcher
-from .minibatch_sampler import MinibatchSampler
+from .item_sampler import ItemSampler
from .utils import datapipe_graph_to_adjlist
@@ -26,7 +26,7 @@ class SingleProcessDataLoader(torch.utils.data.DataLoader):
# dataloader as-is.
#
# The exception is that batch_size should be None, since we already
-# have minibatch sampling and collating in MinibatchSampler.
+# have minibatch sampling and collating in ItemSampler.
def __init__(self, datapipe):
super().__init__(datapipe, batch_size=None, num_workers=0)
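
The comment above is the key constraint: ItemSampler already batches and
collates, so the wrapping torch DataLoader must not batch again. A minimal
sketch of that composition, assuming SingleProcessDataLoader as defined in
this file:

    import torch
    from dgl import graphbolt as gb

    item_set = gb.ItemSet(torch.arange(0, 10))
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    # batch_size=None: every element the datapipe yields is already a
    # collated minibatch, so the DataLoader passes it through unchanged.
    dataloader = SingleProcessDataLoader(datapipe)
    for minibatch in dataloader:
        print(minibatch)  # tensors of 4, 4, and 2 node IDs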
@@ -77,7 +77,7 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
def __init__(self, datapipe, num_workers=0):
# Multiprocessing requires two modifications to the datapipe:
#
-# 1. Insert a stage after MinibatchSampler to distribute the
+# 1. Insert a stage after ItemSampler to distribute the
# minibatches evenly across processes.
# 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
# of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
@@ -88,16 +88,16 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
# (1) Insert minibatch distribution.
# TODO(BarclayII): Currently I'm using sharding_filter() as a
# concept demonstration. Later on minibatch distribution should be
-# merged into MinibatchSampler to maximize efficiency.
-minibatch_samplers = dp_utils.find_dps(
+# merged into ItemSampler to maximize efficiency.
+item_samplers = dp_utils.find_dps(
datapipe_graph,
-MinibatchSampler,
+ItemSampler,
)
-for minibatch_sampler in minibatch_samplers:
+for item_sampler in item_samplers:
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
-minibatch_sampler,
-minibatch_sampler.sharding_filter(),
+item_sampler,
+item_sampler.sharding_filter(),
)
# (2) Cut datapipe at FeatureFetcher and wrap.
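
The hunk above rewrites the datapipe graph rather than the pipeline objects
directly. A standalone sketch of the same splice, assuming torchdata's
traverse_dps/find_dps/replace_dp utilities imported in this file as dp_utils:

    import torch
    import torchdata.dataloader2.graph as dp_utils
    from dgl import graphbolt as gb

    datapipe = gb.ItemSampler(gb.ItemSet(torch.arange(0, 10)), batch_size=4)
    datapipe_graph = dp_utils.traverse_dps(datapipe)
    # Append a sharding_filter after every ItemSampler so that each worker
    # process receives a disjoint share of the minibatches instead of a copy.
    for sampler in dp_utils.find_dps(datapipe_graph, gb.ItemSampler):
        datapipe_graph = dp_utils.replace_dp(
            datapipe_graph, sampler, sampler.sharding_filter()
        )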
@@ -64,10 +64,10 @@ class NeighborSampler(SubgraphSampler):
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
->>> minibatch_sampler = gb.MinibatchSampler(
+>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=1,
... )
->>> data_block_converter = Mapper(minibatch_sampler, to_link_block)
+>>> data_block_converter = Mapper(item_sampler, to_link_block)
>>> neg_sampler = gb.UniformNegativeSampler(
... data_block_converter, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
@@ -175,10 +175,10 @@ class LayerNeighborSampler(NeighborSampler):
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
->>> minibatch_sampler = gb.MinibatchSampler(
+>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=1,
... )
->>> data_block_converter = Mapper(minibatch_sampler, to_link_block)
+>>> data_block_converter = Mapper(item_sampler, to_link_block)
>>> neg_sampler = gb.UniformNegativeSampler(
... data_block_converter, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
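
Both docstrings embed the renamed sampler in the same link-prediction
pipeline, truncated above at the fanouts line. A condensed sketch of the full
flow under the docstrings' assumptions (`graph` and `to_link_block` are
defined elsewhere in those examples):

    import torch
    from torchdata.datapipes.iter import Mapper
    from dgl import graphbolt as gb

    item_set = gb.ItemSet((torch.tensor([0, 1]), torch.tensor([1, 2])))
    item_sampler = gb.ItemSampler(item_set, batch_size=1)
    data_block_converter = Mapper(item_sampler, to_link_block)
    neg_sampler = gb.UniformNegativeSampler(
        data_block_converter, 2, gb.LinkPredictionEdgeFormat.INDEPENDENT, graph
    )
    fanouts = [torch.LongTensor([5]), torch.LongTensor([10])]
    subgraph_sampler = gb.NeighborSampler(neg_sampler, graph, fanouts)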
@@ -44,11 +44,11 @@ class UniformNegativeSampler(NegativeSampler):
>>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
->>> minibatch_sampler = gb.MinibatchSampler(
+>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=1,
... )
>>> neg_sampler = gb.UniformNegativeSampler(
-... minibatch_sampler, 2, output_format, graph)
+... item_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data)
...
@@ -62,11 +62,11 @@ class UniformNegativeSampler(NegativeSampler):
>>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
->>> minibatch_sampler = gb.MinibatchSampler(
+>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=1,
... )
>>> neg_sampler = gb.UniformNegativeSampler(
-... minibatch_sampler, 2, output_format, graph)
+... item_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data)
...
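
The two examples above differ only in the output_format argument: INDEPENDENT
emits each negative as its own labeled edge, while CONDITIONED keeps the
negatives grouped with their positive pair (the exact printed structure is
elided in the truncated docstring). A hedged sketch of toggling between them,
with `graph` assumed as in the docstring:

    import torch
    from dgl import graphbolt as gb

    item_set = gb.ItemSet((torch.tensor([0, 1]), torch.tensor([1, 2])))
    for output_format in (
        gb.LinkPredictionEdgeFormat.INDEPENDENT,
        gb.LinkPredictionEdgeFormat.CONDITIONED,
    ):
        item_sampler = gb.ItemSampler(item_set, batch_size=1)
        neg_sampler = gb.UniformNegativeSampler(
            item_sampler, 2, output_format, graph
        )
        for data in neg_sampler:
            print(data)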
"""Minibatch Sampler"""
"""Item Sampler"""
from collections.abc import Mapping
from functools import partial
@@ -11,17 +11,17 @@ from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph
from .itemset import ItemSet, ItemSetDict
-__all__ = ["MinibatchSampler"]
+__all__ = ["ItemSampler"]
-class MinibatchSampler(IterDataPipe):
-"""Minibatch Sampler.
+class ItemSampler(IterDataPipe):
+"""Item Sampler.
-Creates mini-batches of data which could be node/edge IDs, node pairs with
+Creates item subset of data which could be node/edge IDs, node pairs with
or without labels, head/tail/negative_tails, DGLGraphs and heterogeneous
counterparts.
-Note: This class `MinibatchSampler` is not decorated with
+Note: This class `ItemSampler` is not decorated with
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it
does not support function-like call. But any iterable datapipes from
`torchdata` can be further appended.
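
In practice the note means `ItemSampler` cannot be invoked functionally off
another datapipe, but torchdata's functional stages can still be chained onto
it. A tiny sketch, mirroring example 6 and test_append_with_other_datapipes
below:

    import torch
    from dgl import graphbolt as gb

    data_pipe = gb.ItemSampler(gb.ItemSet(torch.arange(0, 10)), 4)
    data_pipe = data_pipe.map(lambda batch: batch + 1)  # functional append works
    data_pipe = data_pipe.enumerate()
    for expected_idx, (idx, batch) in enumerate(data_pipe):
        assert expected_idx == idx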
@@ -29,7 +29,7 @@ class MinibatchSampler(IterDataPipe):
Parameters
----------
item_set : ItemSet or ItemSetDict
-Data to be sampled for mini-batches.
+Data to be sampled.
batch_size : int
The size of each batch.
drop_last : bool
@@ -43,18 +43,18 @@ class MinibatchSampler(IterDataPipe):
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
->>> minibatch_sampler = gb.MinibatchSampler(
+>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
->>> list(minibatch_sampler)
+>>> list(item_sampler)
[tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])]
2. Node pairs.
>>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20)))
->>> minibatch_sampler = gb.MinibatchSampler(
+>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
->>> list(minibatch_sampler)
+>>> list(item_sampler)
[[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]),
tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])]
@@ -62,8 +62,8 @@ class MinibatchSampler(IterDataPipe):
>>> item_set = gb.ItemSet(
... (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15))
... )
->>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
->>> list(minibatch_sampler)
+>>> item_sampler = gb.ItemSampler(item_set, 3)
+>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])],
[tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]]
@@ -72,8 +72,8 @@ class MinibatchSampler(IterDataPipe):
>>> tails = torch.arange(5, 10)
>>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1)
>>> item_set = gb.ItemSet((heads, tails, negative_tails))
->>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
->>> list(minibatch_sampler)
+>>> item_sampler = gb.ItemSampler(item_set, 3)
+>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]),
tensor([[1, 2], [2, 3], [3, 4]])],
[tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]]
@@ -82,8 +82,8 @@ class MinibatchSampler(IterDataPipe):
>>> import dgl
>>> graphs = [ dgl.rand_graph(10, 20) for _ in range(5) ]
>>> item_set = gb.ItemSet(graphs)
->>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
->>> list(minibatch_sampler)
+>>> item_sampler = gb.ItemSampler(item_set, 3)
+>>> list(item_sampler)
[Graph(num_nodes=30, num_edges=60,
ndata_schemes={}
edata_schemes={}),
@@ -94,7 +94,7 @@ class MinibatchSampler(IterDataPipe):
6. Further process batches with other datapipes such as
`torchdata.datapipes.iter.Mapper`.
>>> item_set = gb.ItemSet(torch.arange(0, 10))
->>> data_pipe = gb.MinibatchSampler(item_set, 4)
+>>> data_pipe = gb.ItemSampler(item_set, 4)
>>> def add_one(batch):
... return batch + 1
>>> data_pipe = data_pipe.map(add_one)
@@ -107,8 +107,8 @@ class MinibatchSampler(IterDataPipe):
... "item": gb.ItemSet(torch.arange(0, 6)),
... }
>>> item_set = gb.ItemSetDict(ids)
->>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
->>> list(minibatch_sampler)
+>>> item_sampler = gb.ItemSampler(item_set, 4)
+>>> list(item_sampler)
[{'user': tensor([0, 1, 2, 3])},
{'item': tensor([0, 1, 2]), 'user': tensor([4])},
{'item': tensor([3, 4, 5])}]
@@ -120,8 +120,8 @@ class MinibatchSampler(IterDataPipe):
... "user:like:item": gb.ItemSet(node_pairs_like),
... "user:follow:user": gb.ItemSet(node_pairs_follow),
... })
->>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
->>> list(minibatch_sampler)
+>>> item_sampler = gb.ItemSampler(item_set, 4)
+>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]},
@@ -136,8 +136,8 @@ class MinibatchSampler(IterDataPipe):
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... })
->>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
->>> list(minibatch_sampler)
+>>> item_sampler = gb.ItemSampler(item_set, 4)
+>>> list(item_sampler)
[{"user:like:item":
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([4])],
@@ -157,8 +157,8 @@ class MinibatchSampler(IterDataPipe):
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... })
->>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
->>> list(minibatch_sampler)
+>>> item_sampler = gb.ItemSampler(item_set, 4)
+>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])],
@@ -21,8 +21,8 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
)
batch_size = 10
-minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-data_block_converter = Mapper(minibatch_sampler, to_data_block)
+item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+data_block_converter = Mapper(item_sampler, to_data_block)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
data_block_converter,
@@ -54,8 +54,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
)
batch_size = 10
-minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-data_block_converter = Mapper(minibatch_sampler, to_data_block)
+item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+data_block_converter = Mapper(item_sampler, to_data_block)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
data_block_converter,
@@ -90,8 +90,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
)
batch_size = 10
-minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-data_block_converter = Mapper(minibatch_sampler, to_data_block)
+item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+data_block_converter = Mapper(item_sampler, to_data_block)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
data_block_converter,
@@ -124,8 +124,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
)
batch_size = 10
-minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-data_block_converter = Mapper(minibatch_sampler, to_data_block)
+item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+data_block_converter = Mapper(item_sampler, to_data_block)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
data_block_converter,
@@ -199,8 +199,8 @@ def test_NegativeSampler_Hetero_Data(format):
}
)
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
-data_block_converter = Mapper(minibatch_dp, to_link_block)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
+data_block_converter = Mapper(item_sampler_dp, to_link_block)
negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph
)
@@ -10,7 +10,7 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
-dp = gb.MinibatchSampler(torch.randn(20), 4)
+dp = gb.ItemSampler(torch.randn(20), 4)
dp = gb.CopyTo(dp, "cuda")
for data in dp:
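
The updated test keeps the established pattern: batch on CPU with
ItemSampler, then move each minibatch to the GPU with CopyTo. A sketch of
that pattern (needs a CUDA device, per the skipIf guard above):

    import torch
    from dgl import graphbolt as gb

    dp = gb.ItemSampler(torch.randn(20), 4)  # five minibatches of 4 floats
    dp = gb.CopyTo(dp, "cuda")               # move each minibatch to the GPU
    for data in dp:
        assert data.device.type == "cuda"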
@@ -16,10 +16,10 @@ def test_FeatureFetcher_homo():
feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10))
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
+data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block)
sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
@@ -52,8 +52,8 @@ def test_FeatureFetcher_with_edges_homo():
feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10))
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
-converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
+converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5
@@ -103,10 +103,10 @@ def test_FeatureFetcher_hetero():
"n2": gb.ItemSet(torch.LongTensor([0, 1, 2])),
}
)
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
+data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block)
sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
@@ -148,8 +148,8 @@ def test_FeatureFetcher_with_edges_hetero():
"n1": gb.ItemSet(torch.randint(0, 20, (10,))),
}
)
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
-converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
+converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
fetcher_dp = gb.FeatureFetcher(
converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
)
@@ -12,11 +12,11 @@ def test_ItemSet_node_ids(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 103
item_set = gb.ItemSet(torch.arange(0, num_ids))
-minibatch_sampler = gb.MinibatchSampler(
+item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
-for i, minibatch in enumerate(minibatch_sampler):
+for i, minibatch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch) == batch_size
@@ -43,12 +43,12 @@ def test_ItemSet_graphs(batch_size, shuffle, drop_last):
for i in range(num_graphs)
]
item_set = gb.ItemSet(graphs)
-minibatch_sampler = gb.MinibatchSampler(
+item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_num_nodes = []
minibatch_num_edges = []
-for i, minibatch in enumerate(minibatch_sampler):
+for i, minibatch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= num_graphs
if not is_last or num_graphs % batch_size == 0:
assert minibatch.batch_size == batch_size
@@ -79,12 +79,12 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
num_ids = 103
node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
item_set = gb.ItemSet(node_pairs)
-minibatch_sampler = gb.MinibatchSampler(
+item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
-for i, (src, dst) in enumerate(minibatch_sampler):
+for i, (src, dst) in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
@@ -115,13 +115,13 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
labels = torch.arange(0, num_ids)
item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels))
-minibatch_sampler = gb.MinibatchSampler(
+item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
labels = []
-for i, (src, dst, label) in enumerate(minibatch_sampler):
+for i, (src, dst, label) in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
@@ -163,13 +163,13 @@ def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
assert heads[i] == head
assert tails[i] == tail
assert torch.equal(neg_tails[i], negs)
-minibatch_sampler = gb.MinibatchSampler(
+item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
head_ids = []
tail_ids = []
negs_ids = []
-for i, (head, tail, negs) in enumerate(minibatch_sampler):
+for i, (head, tail, negs) in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
@@ -204,7 +204,7 @@ def test_append_with_other_datapipes():
num_ids = 100
batch_size = 4
item_set = gb.ItemSet(torch.arange(0, num_ids))
-data_pipe = gb.MinibatchSampler(item_set, batch_size)
+data_pipe = gb.ItemSampler(item_set, batch_size)
# torchdata.datapipes.iter.Enumerator
data_pipe = data_pipe.enumerate()
for i, (idx, data) in enumerate(data_pipe):
@@ -226,11 +226,11 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids)
-minibatch_sampler = gb.MinibatchSampler(
+item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
-for i, batch in enumerate(minibatch_sampler):
+for i, batch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
@@ -270,12 +270,12 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
"user:follow:user": gb.ItemSet(node_pairs_1),
}
item_set = gb.ItemSetDict(node_pairs_dict)
-minibatch_sampler = gb.MinibatchSampler(
+item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
-for i, batch in enumerate(minibatch_sampler):
+for i, batch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
@@ -327,13 +327,13 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
),
}
item_set = gb.ItemSetDict(node_pairs_dict)
-minibatch_sampler = gb.MinibatchSampler(
+item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
labels = []
-for i, batch in enumerate(minibatch_sampler):
+for i, batch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
@@ -384,13 +384,13 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
"user:follow:user": gb.ItemSet((heads, tails, neg_tails)),
}
item_set = gb.ItemSetDict(data_dict)
-minibatch_sampler = gb.MinibatchSampler(
+item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
head_ids = []
tail_ids = []
negs_ids = []
-for i, batch in enumerate(minibatch_sampler):
+for i, batch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
@@ -22,8 +22,8 @@ def test_DataLoader():
features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
feature_store = dgl.graphbolt.BasicFeatureStore(features)
-minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
-block_converter = Mapper(minibatch_sampler, gb_test_utils.to_node_block)
+item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
+block_converter = Mapper(item_sampler, gb_test_utils.to_node_block)
subgraph_sampler = dgl.graphbolt.NeighborSampler(
block_converter,
graph,
@@ -24,8 +24,8 @@ def test_DataLoader():
features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
feature_store = dgl.graphbolt.BasicFeatureStore(features)
-minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
-block_converter = Mapper(minibatch_sampler, to_node_block)
+item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
+block_converter = Mapper(item_sampler, to_node_block)
subgraph_sampler = dgl.graphbolt.NeighborSampler(
block_converter,
graph,
@@ -15,10 +15,10 @@ def to_node_block(data):
def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(torch.arange(10))
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-data_block_converter = Mapper(minibatch_dp, to_node_block)
+data_block_converter = Mapper(item_sampler_dp, to_node_block)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
sampler_dp = Sampler(data_block_converter, graph, fanouts)
assert len(list(sampler_dp)) == 5
@@ -38,10 +38,10 @@ def test_SubgraphSampler_Link(labor):
torch.arange(10, 20),
)
)
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-data_block_converter = Mapper(minibatch_dp, to_link_block)
+data_block_converter = Mapper(item_sampler_dp, to_link_block)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(data_block_converter, graph, fanouts)
assert len(list(neighbor_dp)) == 5
@@ -65,10 +65,10 @@ def test_SubgraphSampler_Link_With_Negative(format, labor):
torch.arange(10, 20),
)
)
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-data_block_converter = Mapper(minibatch_dp, to_link_block)
+data_block_converter = Mapper(item_sampler_dp, to_link_block)
negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph
)
@@ -119,10 +119,10 @@ def test_SubgraphSampler_Link_Hetero(labor):
}
)
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-data_block_converter = Mapper(minibatch_dp, to_link_block)
+data_block_converter = Mapper(item_sampler_dp, to_link_block)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(data_block_converter, graph, fanouts)
assert len(list(neighbor_dp)) == 5
@@ -157,10 +157,10 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
}
)
-minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-data_block_converter = Mapper(minibatch_dp, to_link_block)
+data_block_converter = Mapper(item_sampler_dp, to_link_block)
negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph
)