Unverified Commit 240e28a2 authored by Rhett Ying, committed by GitHub

[GraphBolt] rename MinibatchSampler as ItemSampler (#6255)

parent fc366945
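
For downstream users this commit is a pure rename; a minimal before/after sketch, assuming the `gb` alias for `dgl.graphbolt` used throughout the diff below:

import torch
from dgl import graphbolt as gb

item_set = gb.ItemSet(torch.arange(0, 10))
# Old spelling, removed by this commit:
#   sampler = gb.MinibatchSampler(item_set, batch_size=4)
# New spelling:
sampler = gb.ItemSampler(item_set, batch_size=4)
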
@@ -14,7 +14,7 @@ from .feature_fetcher import *
 from .feature_store import *
 from .impl import *
 from .itemset import *
-from .minibatch_sampler import *
+from .item_sampler import *
 from .negative_sampler import *
 from .sampled_subgraph import *
 from .subgraph_sampler import *
...
@@ -5,7 +5,7 @@ import torchdata.dataloader2.graph as dp_utils
 import torchdata.datapipes as dp
 from .feature_fetcher import FeatureFetcher
-from .minibatch_sampler import MinibatchSampler
+from .item_sampler import ItemSampler
 from .utils import datapipe_graph_to_adjlist
@@ -26,7 +26,7 @@ class SingleProcessDataLoader(torch.utils.data.DataLoader):
     # dataloader as-is.
     #
     # The exception is that batch_size should be None, since we already
-    # have minibatch sampling and collating in MinibatchSampler.
+    # have minibatch sampling and collating in ItemSampler.
     def __init__(self, datapipe):
         super().__init__(datapipe, batch_size=None, num_workers=0)
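
The comment above captures the key constraint: the datapipe already yields collated minibatches, so the wrapping DataLoader must not batch again. A minimal sketch of that behavior with plain torchdata datapipes standing in for the GraphBolt pipeline (toy data, not the actual GraphBolt code):

from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

# Stand-in for ItemSampler: a datapipe whose items are already whole
# minibatches (lists of 4 ints here).
batched_pipe = IterableWrapper(range(8)).batch(4)

# batch_size=None disables the DataLoader's own batching, so each item
# from the datapipe passes through as one minibatch.
loader = DataLoader(batched_pipe, batch_size=None, num_workers=0)
for minibatch in loader:
    print(minibatch)  # [0, 1, 2, 3], then [4, 5, 6, 7]
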
@@ -77,7 +77,7 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
     def __init__(self, datapipe, num_workers=0):
         # Multiprocessing requires two modifications to the datapipe:
         #
-        # 1. Insert a stage after MinibatchSampler to distribute the
+        # 1. Insert a stage after ItemSampler to distribute the
         #    minibatches evenly across processes.
         # 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
         #    of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
@@ -88,16 +88,16 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
         # (1) Insert minibatch distribution.
         # TODO(BarclayII): Currently I'm using sharding_filter() as a
         # concept demonstration. Later on minibatch distribution should be
-        # merged into MinibatchSampler to maximize efficiency.
-        minibatch_samplers = dp_utils.find_dps(
+        # merged into ItemSampler to maximize efficiency.
+        item_samplers = dp_utils.find_dps(
             datapipe_graph,
-            MinibatchSampler,
+            ItemSampler,
         )
-        for minibatch_sampler in minibatch_samplers:
+        for item_sampler in item_samplers:
             datapipe_graph = dp_utils.replace_dp(
                 datapipe_graph,
-                minibatch_sampler,
-                minibatch_sampler.sharding_filter(),
+                item_sampler,
+                item_sampler.sharding_filter(),
             )
         # (2) Cut datapipe at FeatureFetcher and wrap.
...
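
Step (1) above rewires the datapipe graph in place: every ItemSampler found in the traversed graph is replaced by itself followed by a sharding_filter(), so each worker process keeps a disjoint subset of minibatches. A sketch of the same rewiring on a toy pipeline, where a plain Batcher stands in for ItemSampler; `find_dps` and `replace_dp` appear in this file's imports, while `traverse_dps` is assumed from torchdata's `dataloader2.graph` module:

import torchdata.dataloader2.graph as dp_utils
from torchdata.datapipes.iter import Batcher, IterableWrapper

# Toy pipeline: source -> batcher (stand-in for ItemSampler) -> mapper.
pipe = IterableWrapper(range(100)).batch(4).map(sum)

# Traverse the pipeline into a graph, find every Batcher, and splice a
# sharding_filter() in right after each one.
datapipe_graph = dp_utils.traverse_dps(pipe)
for batcher in dp_utils.find_dps(datapipe_graph, Batcher):
    datapipe_graph = dp_utils.replace_dp(
        datapipe_graph,
        batcher,
        batcher.sharding_filter(),
    )
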
@@ -64,10 +64,10 @@ class NeighborSampler(SubgraphSampler):
     >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
     >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
     >>> item_set = gb.ItemSet(node_pairs)
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=1,
     ... )
-    >>> data_block_converter = Mapper(minibatch_sampler, to_link_block)
+    >>> data_block_converter = Mapper(item_sampler, to_link_block)
     >>> neg_sampler = gb.UniformNegativeSampler(
     ...     data_block_converter, 2, data_format, graph)
     >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
@@ -175,10 +175,10 @@ class LayerNeighborSampler(NeighborSampler):
     >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
     >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
     >>> item_set = gb.ItemSet(node_pairs)
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=1,
     ... )
-    >>> data_block_converter = Mapper(minibatch_sampler, to_link_block)
+    >>> data_block_converter = Mapper(item_sampler, to_link_block)
     >>> neg_sampler = gb.UniformNegativeSampler(
     ...     data_block_converter, 2, data_format, graph)
     >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...
@@ -44,11 +44,11 @@ class UniformNegativeSampler(NegativeSampler):
     >>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
     >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
     >>> item_set = gb.ItemSet(node_pairs)
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=1,
     ... )
     >>> neg_sampler = gb.UniformNegativeSampler(
-    ...     minibatch_sampler, 2, output_format, graph)
+    ...     item_sampler, 2, output_format, graph)
     >>> for data in neg_sampler:
     ...     print(data)
     ...
@@ -62,11 +62,11 @@ class UniformNegativeSampler(NegativeSampler):
     >>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
     >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
     >>> item_set = gb.ItemSet(node_pairs)
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=1,
     ... )
     >>> neg_sampler = gb.UniformNegativeSampler(
-    ...     minibatch_sampler, 2, output_format, graph)
+    ...     item_sampler, 2, output_format, graph)
     >>> for data in neg_sampler:
     ...     print(data)
     ...
...
"""Minibatch Sampler""" """Item Sampler"""
from collections.abc import Mapping from collections.abc import Mapping
from functools import partial from functools import partial
...@@ -11,17 +11,17 @@ from ..batch import batch as dgl_batch ...@@ -11,17 +11,17 @@ from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph from ..heterograph import DGLGraph
from .itemset import ItemSet, ItemSetDict from .itemset import ItemSet, ItemSetDict
__all__ = ["MinibatchSampler"] __all__ = ["ItemSampler"]
class MinibatchSampler(IterDataPipe): class ItemSampler(IterDataPipe):
"""Minibatch Sampler. """Item Sampler.
Creates mini-batches of data which could be node/edge IDs, node pairs with Creates item subset of data which could be node/edge IDs, node pairs with
or without labels, head/tail/negative_tails, DGLGraphs and heterogeneous or without labels, head/tail/negative_tails, DGLGraphs and heterogeneous
counterparts. counterparts.
Note: This class `MinibatchSampler` is not decorated with Note: This class `ItemSampler` is not decorated with
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it `torchdata.datapipes.functional_datapipe` on purpose. This indicates it
does not support function-like call. But any iterable datapipes from does not support function-like call. But any iterable datapipes from
`torchdata` can be further appended. `torchdata` can be further appended.
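
In practice the note means construction is always explicit while downstream chaining stays functional. A short sketch under that reading, using the same names as the examples below:

import torch
from dgl import graphbolt as gb

item_set = gb.ItemSet(torch.arange(0, 10))
# No functional form is registered, so the sampler is constructed
# directly rather than via something like item_set.item_sampler(...).
item_sampler = gb.ItemSampler(item_set, batch_size=4)
# Registered torchdata datapipes can still be appended functionally:
data_pipe = item_sampler.map(lambda batch: batch + 1)
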
@@ -29,7 +29,7 @@ class MinibatchSampler(IterDataPipe):
     Parameters
     ----------
     item_set : ItemSet or ItemSetDict
-        Data to be sampled for mini-batches.
+        Data to be sampled.
     batch_size : int
         The size of each batch.
     drop_last : bool
@@ -43,18 +43,18 @@ class MinibatchSampler(IterDataPipe):
     >>> import torch
     >>> from dgl import graphbolt as gb
     >>> item_set = gb.ItemSet(torch.arange(0, 10))
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=4, shuffle=True, drop_last=False
     ... )
-    >>> list(minibatch_sampler)
+    >>> list(item_sampler)
     [tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])]
     2. Node pairs.
     >>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20)))
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=4, shuffle=True, drop_last=False
     ... )
-    >>> list(minibatch_sampler)
+    >>> list(item_sampler)
     [[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]),
     tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])]
@@ -62,8 +62,8 @@ class MinibatchSampler(IterDataPipe):
     >>> item_set = gb.ItemSet(
     ...     (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15))
     ... )
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 3)
+    >>> list(item_sampler)
     [[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])],
     [tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]]
@@ -72,8 +72,8 @@ class MinibatchSampler(IterDataPipe):
     >>> tails = torch.arange(5, 10)
     >>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1)
     >>> item_set = gb.ItemSet((heads, tails, negative_tails))
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 3)
+    >>> list(item_sampler)
     [[tensor([0, 1, 2]), tensor([5, 6, 7]),
     tensor([[1, 2], [2, 3], [3, 4]])],
     [tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]]
@@ -82,8 +82,8 @@ class MinibatchSampler(IterDataPipe):
     >>> import dgl
     >>> graphs = [ dgl.rand_graph(10, 20) for _ in range(5) ]
     >>> item_set = gb.ItemSet(graphs)
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 3)
+    >>> list(item_sampler)
     [Graph(num_nodes=30, num_edges=60,
     ndata_schemes={}
     edata_schemes={}),
@@ -94,7 +94,7 @@ class MinibatchSampler(IterDataPipe):
     6. Further process batches with other datapipes such as
     `torchdata.datapipes.iter.Mapper`.
     >>> item_set = gb.ItemSet(torch.arange(0, 10))
-    >>> data_pipe = gb.MinibatchSampler(item_set, 4)
+    >>> data_pipe = gb.ItemSampler(item_set, 4)
     >>> def add_one(batch):
     ...     return batch + 1
     >>> data_pipe = data_pipe.map(add_one)
@@ -107,8 +107,8 @@ class MinibatchSampler(IterDataPipe):
     ...     "item": gb.ItemSet(torch.arange(0, 6)),
     ... }
     >>> item_set = gb.ItemSetDict(ids)
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 4)
+    >>> list(item_sampler)
     [{'user': tensor([0, 1, 2, 3])},
     {'item': tensor([0, 1, 2]), 'user': tensor([4])},
     {'item': tensor([3, 4, 5])}]
@@ -120,8 +120,8 @@ class MinibatchSampler(IterDataPipe):
     ...     "user:like:item": gb.ItemSet(node_pairs_like),
     ...     "user:follow:user": gb.ItemSet(node_pairs_follow),
     ... })
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 4)
+    >>> list(item_sampler)
     [{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
     {"user:like:item": [tensor([4]), tensor([4])],
     "user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]},
@@ -136,8 +136,8 @@ class MinibatchSampler(IterDataPipe):
     ...     "user:like:item": gb.ItemSet(like),
     ...     "user:follow:user": gb.ItemSet(follow),
     ... })
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 4)
+    >>> list(item_sampler)
     [{"user:like:item":
     [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
     {"user:like:item": [tensor([4]), tensor([4]), tensor([4])],
@@ -157,8 +157,8 @@ class MinibatchSampler(IterDataPipe):
     ...     "user:like:item": gb.ItemSet(like),
     ...     "user:follow:user": gb.ItemSet(follow),
     ... })
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 4)
+    >>> list(item_sampler)
     [{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
     tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
     {"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])],
...
@@ -21,8 +21,8 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
         )
     )
     batch_size = 10
-    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-    data_block_converter = Mapper(minibatch_sampler, to_data_block)
+    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+    data_block_converter = Mapper(item_sampler, to_data_block)
     # Construct NegativeSampler.
     negative_sampler = gb.UniformNegativeSampler(
         data_block_converter,
@@ -54,8 +54,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
         )
     )
     batch_size = 10
-    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-    data_block_converter = Mapper(minibatch_sampler, to_data_block)
+    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+    data_block_converter = Mapper(item_sampler, to_data_block)
     # Construct NegativeSampler.
     negative_sampler = gb.UniformNegativeSampler(
         data_block_converter,
@@ -90,8 +90,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
         )
     )
     batch_size = 10
-    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-    data_block_converter = Mapper(minibatch_sampler, to_data_block)
+    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+    data_block_converter = Mapper(item_sampler, to_data_block)
     # Construct NegativeSampler.
     negative_sampler = gb.UniformNegativeSampler(
         data_block_converter,
@@ -124,8 +124,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
         )
     )
     batch_size = 10
-    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-    data_block_converter = Mapper(minibatch_sampler, to_data_block)
+    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+    data_block_converter = Mapper(item_sampler, to_data_block)
     # Construct NegativeSampler.
     negative_sampler = gb.UniformNegativeSampler(
         data_block_converter,
@@ -199,8 +199,8 @@ def test_NegativeSampler_Hetero_Data(format):
         }
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     negative_dp = gb.UniformNegativeSampler(
         data_block_converter, 1, format, graph
     )
...
@@ -10,7 +10,7 @@ import torch
 @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
 def test_CopyTo():
-    dp = gb.MinibatchSampler(torch.randn(20), 4)
+    dp = gb.ItemSampler(torch.randn(20), 4)
     dp = gb.CopyTo(dp, "cuda")
     for data in dp:
...
@@ -16,10 +16,10 @@ def test_FeatureFetcher_homo():
     feature_store = gb.BasicFeatureStore(features)
     itemset = gb.ItemSet(torch.arange(10))
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
+    data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block)
     sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
     fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
@@ -52,8 +52,8 @@ def test_FeatureFetcher_with_edges_homo():
     feature_store = gb.BasicFeatureStore(features)
     itemset = gb.ItemSet(torch.arange(10))
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
-    converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
+    converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
     fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
     assert len(list(fetcher_dp)) == 5
@@ -103,10 +103,10 @@ def test_FeatureFetcher_hetero():
             "n2": gb.ItemSet(torch.LongTensor([0, 1, 2])),
         }
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
+    data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block)
     sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
     fetcher_dp = gb.FeatureFetcher(
         sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
@@ -148,8 +148,8 @@ def test_FeatureFetcher_with_edges_hetero():
             "n1": gb.ItemSet(torch.randint(0, 20, (10,))),
         }
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
-    converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
+    converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
     fetcher_dp = gb.FeatureFetcher(
         converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
     )
...
@@ -12,11 +12,11 @@ def test_ItemSet_node_ids(batch_size, shuffle, drop_last):
     # Node IDs.
     num_ids = 103
     item_set = gb.ItemSet(torch.arange(0, num_ids))
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     minibatch_ids = []
-    for i, minibatch in enumerate(minibatch_sampler):
+    for i, minibatch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             assert len(minibatch) == batch_size
@@ -43,12 +43,12 @@ def test_ItemSet_graphs(batch_size, shuffle, drop_last):
         for i in range(num_graphs)
     ]
     item_set = gb.ItemSet(graphs)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     minibatch_num_nodes = []
     minibatch_num_edges = []
-    for i, minibatch in enumerate(minibatch_sampler):
+    for i, minibatch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_graphs
         if not is_last or num_graphs % batch_size == 0:
             assert minibatch.batch_size == batch_size
@@ -79,12 +79,12 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
     num_ids = 103
     node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
     item_set = gb.ItemSet(node_pairs)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     src_ids = []
     dst_ids = []
-    for i, (src, dst) in enumerate(minibatch_sampler):
+    for i, (src, dst) in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size
@@ -115,13 +115,13 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
     node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
     labels = torch.arange(0, num_ids)
     item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels))
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     src_ids = []
     dst_ids = []
     labels = []
-    for i, (src, dst, label) in enumerate(minibatch_sampler):
+    for i, (src, dst, label) in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size
@@ -163,13 +163,13 @@ def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
         assert heads[i] == head
         assert tails[i] == tail
         assert torch.equal(neg_tails[i], negs)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
     head_ids = []
     tail_ids = []
     negs_ids = []
-    for i, (head, tail, negs) in enumerate(minibatch_sampler):
+    for i, (head, tail, negs) in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size
@@ -204,7 +204,7 @@ def test_append_with_other_datapipes():
     num_ids = 100
     batch_size = 4
     item_set = gb.ItemSet(torch.arange(0, num_ids))
-    data_pipe = gb.MinibatchSampler(item_set, batch_size)
+    data_pipe = gb.ItemSampler(item_set, batch_size)
     # torchdata.datapipes.iter.Enumerator
     data_pipe = data_pipe.enumerate()
     for i, (idx, data) in enumerate(data_pipe):
@@ -226,11 +226,11 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
     for key, value in ids.items():
         chained_ids += [(key, v) for v in value]
     item_set = gb.ItemSetDict(ids)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     minibatch_ids = []
-    for i, batch in enumerate(minibatch_sampler):
+    for i, batch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size
@@ -270,12 +270,12 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
         "user:follow:user": gb.ItemSet(node_pairs_1),
     }
     item_set = gb.ItemSetDict(node_pairs_dict)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     src_ids = []
     dst_ids = []
-    for i, batch in enumerate(minibatch_sampler):
+    for i, batch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= total_ids
         if not is_last or total_ids % batch_size == 0:
             expected_batch_size = batch_size
@@ -327,13 +327,13 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
         ),
     }
     item_set = gb.ItemSetDict(node_pairs_dict)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     src_ids = []
     dst_ids = []
     labels = []
-    for i, batch in enumerate(minibatch_sampler):
+    for i, batch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= total_ids
         if not is_last or total_ids % batch_size == 0:
             expected_batch_size = batch_size
@@ -384,13 +384,13 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
         "user:follow:user": gb.ItemSet((heads, tails, neg_tails)),
     }
     item_set = gb.ItemSetDict(data_dict)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     head_ids = []
     tail_ids = []
     negs_ids = []
-    for i, batch in enumerate(minibatch_sampler):
+    for i, batch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= total_ids
         if not is_last or total_ids % batch_size == 0:
             expected_batch_size = batch_size
...
@@ -22,8 +22,8 @@ def test_DataLoader():
     features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
     feature_store = dgl.graphbolt.BasicFeatureStore(features)
-    minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
-    block_converter = Mapper(minibatch_sampler, gb_test_utils.to_node_block)
+    item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
+    block_converter = Mapper(item_sampler, gb_test_utils.to_node_block)
     subgraph_sampler = dgl.graphbolt.NeighborSampler(
         block_converter,
         graph,
...
@@ -24,8 +24,8 @@ def test_DataLoader():
     features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
     feature_store = dgl.graphbolt.BasicFeatureStore(features)
-    minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
-    block_converter = Mapper(minibatch_sampler, to_node_block)
+    item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
+    block_converter = Mapper(item_sampler, to_node_block)
     subgraph_sampler = dgl.graphbolt.NeighborSampler(
         block_converter,
         graph,
...
@@ -15,10 +15,10 @@ def to_node_block(data):
 def test_SubgraphSampler_Node(labor):
     graph = gb_test_utils.rand_csc_graph(20, 0.15)
     itemset = gb.ItemSet(torch.arange(10))
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_node_block)
+    data_block_converter = Mapper(item_sampler_dp, to_node_block)
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
     sampler_dp = Sampler(data_block_converter, graph, fanouts)
     assert len(list(sampler_dp)) == 5
@@ -38,10 +38,10 @@ def test_SubgraphSampler_Link(labor):
             torch.arange(10, 20),
        )
    )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
     neighbor_dp = Sampler(data_block_converter, graph, fanouts)
     assert len(list(neighbor_dp)) == 5
@@ -65,10 +65,10 @@ def test_SubgraphSampler_Link_With_Negative(format, labor):
             torch.arange(10, 20),
        )
    )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     negative_dp = gb.UniformNegativeSampler(
         data_block_converter, 1, format, graph
     )
@@ -119,10 +119,10 @@ def test_SubgraphSampler_Link_Hetero(labor):
         }
    )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
     neighbor_dp = Sampler(data_block_converter, graph, fanouts)
     assert len(list(neighbor_dp)) == 5
@@ -157,10 +157,10 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
         }
    )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     negative_dp = gb.UniformNegativeSampler(
         data_block_converter, 1, format, graph
     )
...