OpenDAS / dgl / Commits

Commit 240e28a2 (Unverified)
Authored Sep 01, 2023 by Rhett Ying; committed by GitHub on Sep 01, 2023
[GraphBolt] rename MinibatchSampler as ItemSampler (#6255)
Parent: fc366945

Showing 12 changed files with 96 additions and 96 deletions (+96 −96)
  python/dgl/graphbolt/__init__.py                                    +1  −1
  python/dgl/graphbolt/dataloader.py                                  +9  −9
  python/dgl/graphbolt/impl/neighbor_sampler.py                       +4  −4
  python/dgl/graphbolt/impl/uniform_negative_sampler.py               +4  −4
  python/dgl/graphbolt/item_sampler.py                               +26 −26
  tests/python/pytorch/graphbolt/impl/test_negative_sampler.py       +10 −10
  tests/python/pytorch/graphbolt/test_base.py                         +1  −1
  tests/python/pytorch/graphbolt/test_feature_fetcher.py              +8  −8
  tests/python/pytorch/graphbolt/test_item_sampler.py                +19 −19
  tests/python/pytorch/graphbolt/test_multi_process_dataloader.py     +2  −2
  tests/python/pytorch/graphbolt/test_single_process_dataloader.py    +2  −2
  tests/python/pytorch/graphbolt/test_subgraph_sampler.py            +10 −10
python/dgl/graphbolt/__init__.py

@@ -14,7 +14,7 @@ from .feature_fetcher import *
 from .feature_store import *
 from .impl import *
 from .itemset import *
-from .minibatch_sampler import *
+from .item_sampler import *
 from .negative_sampler import *
 from .sampled_subgraph import *
 from .subgraph_sampler import *
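Because `__init__.py` re-exports the sampler through a wildcard import, the rename is visible at the `dgl.graphbolt` top level. A minimal before/after sketch of what a caller changes, assuming the public alias shown above:

    import torch
    from dgl import graphbolt as gb

    item_set = gb.ItemSet(torch.arange(0, 10))

    # Before this commit:
    #   sampler = gb.MinibatchSampler(item_set, batch_size=4)
    # After this commit, the same behavior under the new name:
    sampler = gb.ItemSampler(item_set, batch_size=4)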
python/dgl/graphbolt/dataloader.py

@@ -5,7 +5,7 @@ import torchdata.dataloader2.graph as dp_utils
 import torchdata.datapipes as dp

 from .feature_fetcher import FeatureFetcher
-from .minibatch_sampler import MinibatchSampler
+from .item_sampler import ItemSampler
 from .utils import datapipe_graph_to_adjlist

@@ -26,7 +26,7 @@ class SingleProcessDataLoader(torch.utils.data.DataLoader):
     # dataloader as-is.
     #
     # The exception is that batch_size should be None, since we already
-    # have minibatch sampling and collating in MinibatchSampler.
+    # have minibatch sampling and collating in ItemSampler.
     def __init__(self, datapipe):
         super().__init__(datapipe, batch_size=None, num_workers=0)

@@ -77,7 +77,7 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
     def __init__(self, datapipe, num_workers=0):
         # Multiprocessing requires two modifications to the datapipe:
         #
-        # 1. Insert a stage after MinibatchSampler to distribute the
+        # 1. Insert a stage after ItemSampler to distribute the
         #    minibatches evenly across processes.
         # 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
         #    of the FeatureFetcher with a multiprocessing PyTorch DataLoader.

@@ -88,16 +88,16 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
         # (1) Insert minibatch distribution.
         # TODO(BarclayII): Currently I'm using sharding_filter() as a
         # concept demonstration. Later on minibatch distribution should be
-        # merged into MinibatchSampler to maximize efficiency.
-        minibatch_samplers = dp_utils.find_dps(
+        # merged into ItemSampler to maximize efficiency.
+        item_samplers = dp_utils.find_dps(
             datapipe_graph,
-            MinibatchSampler,
+            ItemSampler,
         )
-        for minibatch_sampler in minibatch_samplers:
+        for item_sampler in item_samplers:
             datapipe_graph = dp_utils.replace_dp(
                 datapipe_graph,
-                minibatch_sampler,
-                minibatch_sampler.sharding_filter(),
+                item_sampler,
+                item_sampler.sharding_filter(),
             )

         # (2) Cut datapipe at FeatureFetcher and wrap.
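Beyond the rename, this hunk documents how `MultiProcessDataLoader` rewires the datapipe graph. Here is a standalone sketch of that step, assuming torchdata's `dataloader2.graph` utilities (`traverse_dps`, `find_dps`, `replace_dp`) behave as in torchdata 0.6; the toy pipeline itself is illustrative, not the loader's actual input:

    import torch
    import torchdata.dataloader2.graph as dp_utils
    from dgl import graphbolt as gb

    # A toy pipeline ending in an ItemSampler.
    datapipe = gb.ItemSampler(gb.ItemSet(torch.arange(0, 100)), batch_size=8)
    datapipe_graph = dp_utils.traverse_dps(datapipe)

    # Locate every ItemSampler in the graph and append sharding_filter()
    # after it, so each worker process sees a disjoint slice of minibatches.
    for item_sampler in dp_utils.find_dps(datapipe_graph, gb.ItemSampler):
        datapipe_graph = dp_utils.replace_dp(
            datapipe_graph,
            item_sampler,
            item_sampler.sharding_filter(),
        )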
python/dgl/graphbolt/impl/neighbor_sampler.py

@@ -64,10 +64,10 @@ class NeighborSampler(SubgraphSampler):
     >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
     >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
     >>> item_set = gb.ItemSet(node_pairs)
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=1,
     ... )
-    >>> data_block_converter = Mapper(minibatch_sampler, to_link_block)
+    >>> data_block_converter = Mapper(item_sampler, to_link_block)
     >>> neg_sampler = gb.UniformNegativeSampler(
     ...     data_block_converter, 2, data_format, graph)
     >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),

@@ -175,10 +175,10 @@ class LayerNeighborSampler(NeighborSampler):
     >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
     >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
     >>> item_set = gb.ItemSet(node_pairs)
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=1,
     ... )
-    >>> data_block_converter = Mapper(minibatch_sampler, to_link_block)
+    >>> data_block_converter = Mapper(item_sampler, to_link_block)
     >>> neg_sampler = gb.UniformNegativeSampler(
     ...     data_block_converter, 2, data_format, graph)
     >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
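For orientation, the renamed node-sampling pipeline end to end, mirroring `test_SubgraphSampler_Node` from this commit's test suite; `gb_test_utils.rand_csc_graph` and `to_node_block` are test helpers assumed to be importable, so treat this as a sketch rather than a drop-in script:

    import torch
    from torchdata.datapipes.iter import Mapper
    from dgl import graphbolt as gb
    import gb_test_utils  # helper module from tests/python/pytorch/graphbolt

    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(10))
    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
    fanouts = [torch.LongTensor([2]) for _ in range(2)]
    data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block)
    sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
    assert len(list(sampler_dp)) == 5  # 10 seed nodes / batch size of 2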
python/dgl/graphbolt/impl/uniform_negative_sampler.py

@@ -44,11 +44,11 @@ class UniformNegativeSampler(NegativeSampler):
     >>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
     >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
     >>> item_set = gb.ItemSet(node_pairs)
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=1,
     ... )
     >>> neg_sampler = gb.UniformNegativeSampler(
-    ...     minibatch_sampler, 2, output_format, graph)
+    ...     item_sampler, 2, output_format, graph)
     >>> for data in neg_sampler:
     ...     print(data)

@@ -62,11 +62,11 @@ class UniformNegativeSampler(NegativeSampler):
     >>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
     >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
     >>> item_set = gb.ItemSet(node_pairs)
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=1,
     ... )
     >>> neg_sampler = gb.UniformNegativeSampler(
-    ...     minibatch_sampler, 2, output_format, graph)
+    ...     item_sampler, 2, output_format, graph)
     >>> for data in neg_sampler:
     ...     print(data)
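Putting the renamed pieces of this docstring together, a hedged, self-contained version of the INDEPENDENT-format example; the graph construction is borrowed from the commit's tests and is an assumption, not part of the docstring itself:

    import torch
    from dgl import graphbolt as gb
    import gb_test_utils  # test helper; any compatible sampling graph would do

    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
    node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
    item_set = gb.ItemSet(node_pairs)
    item_sampler = gb.ItemSampler(item_set, batch_size=1)
    # Draw 2 negatives per positive edge, in the chosen output format.
    neg_sampler = gb.UniformNegativeSampler(item_sampler, 2, output_format, graph)
    for data in neg_sampler:
        print(data)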
python/dgl/graphbolt/minibatch_sampler.py → python/dgl/graphbolt/item_sampler.py

-"""Minibatch Sampler"""
+"""Item Sampler"""
 from collections.abc import Mapping
 from functools import partial

@@ -11,17 +11,17 @@ from ..batch import batch as dgl_batch
 from ..heterograph import DGLGraph
 from .itemset import ItemSet, ItemSetDict

-__all__ = ["MinibatchSampler"]
+__all__ = ["ItemSampler"]


-class MinibatchSampler(IterDataPipe):
-    """Minibatch Sampler.
+class ItemSampler(IterDataPipe):
+    """Item Sampler.

-    Creates mini-batches of data which could be node/edge IDs, node pairs with
+    Creates item subset of data which could be node/edge IDs, node pairs with
     or without labels, head/tail/negative_tails, DGLGraphs and heterogeneous
     counterparts.

-    Note: This class `MinibatchSampler` is not decorated with
+    Note: This class `ItemSampler` is not decorated with
     `torchdata.datapipes.functional_datapipe` on purpose. This indicates it
     does not support function-like call. But any iterable datapipes from
     `torchdata` can be further appended.

@@ -29,7 +29,7 @@ class MinibatchSampler(IterDataPipe):
     Parameters
     ----------
     item_set : ItemSet or ItemSetDict
-        Data to be sampled for mini-batches.
+        Data to be sampled.
     batch_size : int
         The size of each batch.
     drop_last : bool

@@ -43,18 +43,18 @@ class MinibatchSampler(IterDataPipe):
     >>> import torch
     >>> from dgl import graphbolt as gb
     >>> item_set = gb.ItemSet(torch.arange(0, 10))
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=4, shuffle=True, drop_last=False
     ... )
-    >>> list(minibatch_sampler)
+    >>> list(item_sampler)
     [tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])]

     2. Node pairs.

     >>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20)))
-    >>> minibatch_sampler = gb.MinibatchSampler(
+    >>> item_sampler = gb.ItemSampler(
     ...     item_set, batch_size=4, shuffle=True, drop_last=False
     ... )
-    >>> list(minibatch_sampler)
+    >>> list(item_sampler)
     [[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]),
     tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])]

@@ -62,8 +62,8 @@ class MinibatchSampler(IterDataPipe):
     >>> item_set = gb.ItemSet(
     ...     (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15))
     ... )
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 3)
+    >>> list(item_sampler)
     [[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])],
     [tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]]

@@ -72,8 +72,8 @@ class MinibatchSampler(IterDataPipe):
     >>> tails = torch.arange(5, 10)
     >>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1)
     >>> item_set = gb.ItemSet((heads, tails, negative_tails))
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 3)
+    >>> list(item_sampler)
     [[tensor([0, 1, 2]), tensor([5, 6, 7]),
     tensor([[1, 2], [2, 3], [3, 4]])],
     [tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]]

@@ -82,8 +82,8 @@ class MinibatchSampler(IterDataPipe):
     >>> import dgl
     >>> graphs = [ dgl.rand_graph(10, 20) for _ in range(5) ]
     >>> item_set = gb.ItemSet(graphs)
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 3)
+    >>> list(item_sampler)
     [Graph(num_nodes=30, num_edges=60,
     ndata_schemes={}
     edata_schemes={}),

@@ -94,7 +94,7 @@ class MinibatchSampler(IterDataPipe):
     6. Further process batches with other datapipes such as
     `torchdata.datapipes.iter.Mapper`.

     >>> item_set = gb.ItemSet(torch.arange(0, 10))
-    >>> data_pipe = gb.MinibatchSampler(item_set, 4)
+    >>> data_pipe = gb.ItemSampler(item_set, 4)
     >>> def add_one(batch):
     ...     return batch + 1
     >>> data_pipe = data_pipe.map(add_one)

@@ -107,8 +107,8 @@ class MinibatchSampler(IterDataPipe):
     ...     "item": gb.ItemSet(torch.arange(0, 6)),
     ... }
     >>> item_set = gb.ItemSetDict(ids)
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 4)
+    >>> list(item_sampler)
     [{'user': tensor([0, 1, 2, 3])},
     {'item': tensor([0, 1, 2]), 'user': tensor([4])},
     {'item': tensor([3, 4, 5])}]

@@ -120,8 +120,8 @@ class MinibatchSampler(IterDataPipe):
     ...     "user:like:item": gb.ItemSet(node_pairs_like),
     ...     "user:follow:user": gb.ItemSet(node_pairs_follow),
     ... })
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 4)
+    >>> list(item_sampler)
     [{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
     {"user:like:item": [tensor([4]), tensor([4])],
     "user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]},

@@ -136,8 +136,8 @@ class MinibatchSampler(IterDataPipe):
     ...     "user:like:item": gb.ItemSet(like),
     ...     "user:follow:user": gb.ItemSet(follow),
     ... })
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 4)
+    >>> list(item_sampler)
     [{"user:like:item":
     [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
     {"user:like:item": [tensor([4]), tensor([4]), tensor([4])],

@@ -157,8 +157,8 @@ class MinibatchSampler(IterDataPipe):
     ...     "user:like:item": gb.ItemSet(like),
     ...     "user:follow:user": gb.ItemSet(follow),
     ... })
-    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
-    >>> list(minibatch_sampler)
+    >>> item_sampler = gb.ItemSampler(item_set, 4)
+    >>> list(item_sampler)
     [{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
     tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
     {"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])],
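The docstring examples above pin down the batching semantics the renamed tests below exercise: with N items and batch size B, the sampler yields ceil(N / B) batches, and only the last one may be short. A small sketch of that arithmetic, assuming `drop_last` follows the usual PyTorch convention of discarding the trailing partial batch:

    import torch
    from dgl import graphbolt as gb

    item_set = gb.ItemSet(torch.arange(0, 10))

    # drop_last=False keeps the short trailing batch: 4 + 4 + 2 items.
    sampler = gb.ItemSampler(item_set, batch_size=4, shuffle=False, drop_last=False)
    assert [len(batch) for batch in sampler] == [4, 4, 2]

    # drop_last=True discards it, leaving only the two full batches.
    sampler = gb.ItemSampler(item_set, batch_size=4, shuffle=False, drop_last=True)
    assert [len(batch) for batch in sampler] == [4, 4]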
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py

@@ -21,8 +21,8 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
         )
     )
     batch_size = 10
-    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-    data_block_converter = Mapper(minibatch_sampler, to_data_block)
+    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+    data_block_converter = Mapper(item_sampler, to_data_block)
     # Construct NegativeSampler.
     negative_sampler = gb.UniformNegativeSampler(
         data_block_converter,

@@ -54,8 +54,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
         )
     )
     batch_size = 10
-    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-    data_block_converter = Mapper(minibatch_sampler, to_data_block)
+    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+    data_block_converter = Mapper(item_sampler, to_data_block)
     # Construct NegativeSampler.
     negative_sampler = gb.UniformNegativeSampler(
         data_block_converter,

@@ -90,8 +90,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
         )
     )
     batch_size = 10
-    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-    data_block_converter = Mapper(minibatch_sampler, to_data_block)
+    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+    data_block_converter = Mapper(item_sampler, to_data_block)
     # Construct NegativeSampler.
     negative_sampler = gb.UniformNegativeSampler(
         data_block_converter,

@@ -124,8 +124,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
         )
     )
     batch_size = 10
-    minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=batch_size)
-    data_block_converter = Mapper(minibatch_sampler, to_data_block)
+    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
+    data_block_converter = Mapper(item_sampler, to_data_block)
     # Construct NegativeSampler.
     negative_sampler = gb.UniformNegativeSampler(
         data_block_converter,

@@ -199,8 +199,8 @@ def test_NegativeSampler_Hetero_Data(format):
         }
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     negative_dp = gb.UniformNegativeSampler(
         data_block_converter, 1, format, graph
     )
tests/python/pytorch/graphbolt/test_base.py

@@ -10,7 +10,7 @@ import torch
 @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
 def test_CopyTo():
-    dp = gb.MinibatchSampler(torch.randn(20), 4)
+    dp = gb.ItemSampler(torch.randn(20), 4)
     dp = gb.CopyTo(dp, "cuda")
     for data in dp:
tests/python/pytorch/graphbolt/test_feature_fetcher.py

@@ -16,10 +16,10 @@ def test_FeatureFetcher_homo():
     feature_store = gb.BasicFeatureStore(features)
     itemset = gb.ItemSet(torch.arange(10))
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
+    data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block)
     sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
     fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])

@@ -52,8 +52,8 @@ def test_FeatureFetcher_with_edges_homo():
     feature_store = gb.BasicFeatureStore(features)
     itemset = gb.ItemSet(torch.arange(10))
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
-    converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
+    converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
     fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])

     assert len(list(fetcher_dp)) == 5

@@ -103,10 +103,10 @@ def test_FeatureFetcher_hetero():
             "n2": gb.ItemSet(torch.LongTensor([0, 1, 2])),
         }
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, gb_test_utils.to_node_block)
+    data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block)
     sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
     fetcher_dp = gb.FeatureFetcher(
         sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}

@@ -148,8 +148,8 @@ def test_FeatureFetcher_with_edges_hetero():
             "n1": gb.ItemSet(torch.randint(0, 20, (10,))),
         }
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
-    converter_dp = Mapper(minibatch_dp, add_node_and_edge_ids)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
+    converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
     fetcher_dp = gb.FeatureFetcher(
         converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
     )
tests/python/pytorch/graphbolt/test_minibatch_sampler.py → tests/python/pytorch/graphbolt/test_item_sampler.py

@@ -12,11 +12,11 @@ def test_ItemSet_node_ids(batch_size, shuffle, drop_last):
     # Node IDs.
     num_ids = 103
     item_set = gb.ItemSet(torch.arange(0, num_ids))
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     minibatch_ids = []
-    for i, minibatch in enumerate(minibatch_sampler):
+    for i, minibatch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             assert len(minibatch) == batch_size

@@ -43,12 +43,12 @@ def test_ItemSet_graphs(batch_size, shuffle, drop_last):
         for i in range(num_graphs)
     ]
     item_set = gb.ItemSet(graphs)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     minibatch_num_nodes = []
     minibatch_num_edges = []
-    for i, minibatch in enumerate(minibatch_sampler):
+    for i, minibatch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_graphs
         if not is_last or num_graphs % batch_size == 0:
             assert minibatch.batch_size == batch_size

@@ -79,12 +79,12 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
     num_ids = 103
     node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
     item_set = gb.ItemSet(node_pairs)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     src_ids = []
     dst_ids = []
-    for i, (src, dst) in enumerate(minibatch_sampler):
+    for i, (src, dst) in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size

@@ -115,13 +115,13 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
     node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
     labels = torch.arange(0, num_ids)
     item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels))
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     src_ids = []
     dst_ids = []
     labels = []
-    for i, (src, dst, label) in enumerate(minibatch_sampler):
+    for i, (src, dst, label) in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size

@@ -163,13 +163,13 @@ def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
         assert heads[i] == head
         assert tails[i] == tail
         assert torch.equal(neg_tails[i], negs)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     head_ids = []
     tail_ids = []
     negs_ids = []
-    for i, (head, tail, negs) in enumerate(minibatch_sampler):
+    for i, (head, tail, negs) in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size

@@ -204,7 +204,7 @@ def test_append_with_other_datapipes():
     num_ids = 100
     batch_size = 4
     item_set = gb.ItemSet(torch.arange(0, num_ids))
-    data_pipe = gb.MinibatchSampler(item_set, batch_size)
+    data_pipe = gb.ItemSampler(item_set, batch_size)
     # torchdata.datapipes.iter.Enumerator
     data_pipe = data_pipe.enumerate()
     for i, (idx, data) in enumerate(data_pipe):

@@ -226,11 +226,11 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
     for key, value in ids.items():
         chained_ids += [(key, v) for v in value]
     item_set = gb.ItemSetDict(ids)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     minibatch_ids = []
-    for i, batch in enumerate(minibatch_sampler):
+    for i, batch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size

@@ -270,12 +270,12 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
         "user:follow:user": gb.ItemSet(node_pairs_1),
     }
     item_set = gb.ItemSetDict(node_pairs_dict)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     src_ids = []
     dst_ids = []
-    for i, batch in enumerate(minibatch_sampler):
+    for i, batch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= total_ids
         if not is_last or total_ids % batch_size == 0:
             expected_batch_size = batch_size

@@ -327,13 +327,13 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
         ),
     }
     item_set = gb.ItemSetDict(node_pairs_dict)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     src_ids = []
     dst_ids = []
     labels = []
-    for i, batch in enumerate(minibatch_sampler):
+    for i, batch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= total_ids
         if not is_last or total_ids % batch_size == 0:
             expected_batch_size = batch_size

@@ -384,13 +384,13 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
         "user:follow:user": gb.ItemSet((heads, tails, neg_tails)),
     }
     item_set = gb.ItemSetDict(data_dict)
-    minibatch_sampler = gb.MinibatchSampler(
+    item_sampler = gb.ItemSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
     head_ids = []
     tail_ids = []
     negs_ids = []
-    for i, batch in enumerate(minibatch_sampler):
+    for i, batch in enumerate(item_sampler):
         is_last = (i + 1) * batch_size >= total_ids
         if not is_last or total_ids % batch_size == 0:
             expected_batch_size = batch_size
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py

@@ -22,8 +22,8 @@ def test_DataLoader():
     features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
     feature_store = dgl.graphbolt.BasicFeatureStore(features)
-    minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
-    block_converter = Mapper(minibatch_sampler, gb_test_utils.to_node_block)
+    item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
+    block_converter = Mapper(item_sampler, gb_test_utils.to_node_block)
     subgraph_sampler = dgl.graphbolt.NeighborSampler(
         block_converter,
         graph,
tests/python/pytorch/graphbolt/test_single_process_dataloader.py

@@ -24,8 +24,8 @@ def test_DataLoader():
     features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
     feature_store = dgl.graphbolt.BasicFeatureStore(features)
-    minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
-    block_converter = Mapper(minibatch_sampler, to_node_block)
+    item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
+    block_converter = Mapper(item_sampler, to_node_block)
     subgraph_sampler = dgl.graphbolt.NeighborSampler(
         block_converter,
         graph,
tests/python/pytorch/graphbolt/test_subgraph_sampler.py

@@ -15,10 +15,10 @@ def to_node_block(data):
 def test_SubgraphSampler_Node(labor):
     graph = gb_test_utils.rand_csc_graph(20, 0.15)
     itemset = gb.ItemSet(torch.arange(10))
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_node_block)
+    data_block_converter = Mapper(item_sampler_dp, to_node_block)
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
     sampler_dp = Sampler(data_block_converter, graph, fanouts)
     assert len(list(sampler_dp)) == 5

@@ -38,10 +38,10 @@ def test_SubgraphSampler_Link(labor):
             torch.arange(10, 20),
         )
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
     neighbor_dp = Sampler(data_block_converter, graph, fanouts)
     assert len(list(neighbor_dp)) == 5

@@ -65,10 +65,10 @@ def test_SubgraphSampler_Link_With_Negative(format, labor):
             torch.arange(10, 20),
         )
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     negative_dp = gb.UniformNegativeSampler(
         data_block_converter, 1, format, graph
     )

@@ -119,10 +119,10 @@ def test_SubgraphSampler_Link_Hetero(labor):
         }
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
     neighbor_dp = Sampler(data_block_converter, graph, fanouts)
     assert len(list(neighbor_dp)) == 5

@@ -157,10 +157,10 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
         }
     )
-    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    data_block_converter = Mapper(item_sampler_dp, to_link_block)
     negative_dp = gb.UniformNegativeSampler(
         data_block_converter, 1, format, graph
     )