OpenDAS / dgl · Commits

Unverified commit f6dec359, authored Aug 23, 2023 by peizhou001, committed via GitHub on Aug 23, 2023.

[Graphbolt]Subgraph sampler udf (#6129)

Parent: 4663cb0c

Showing 6 changed files with 387 additions and 72 deletions (+387, -72):
python/dgl/graphbolt/impl/__init__.py (+1, -0)
python/dgl/graphbolt/impl/neighbor_sampler.py (+111, -0)
python/dgl/graphbolt/subgraph_sampler.py (+99, -12)
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py (+21, -11)
tests/python/pytorch/graphbolt/test_single_process_dataloader.py (+18, -17)
tests/python/pytorch/graphbolt/test_subgraph_sampler.py (+137, -32)
python/dgl/graphbolt/impl/__init__.py

@@ -5,3 +5,4 @@ from .torch_based_feature_store import *
 from .csc_sampling_graph import *
 from .sampled_subgraph_impl import *
 from .uniform_negative_sampler import *
+from .neighbor_sampler import *
python/dgl/graphbolt/impl/neighbor_sampler.py (new file, mode 100644)

"""Neighbor subgraph sampler for GraphBolt."""
from ..subgraph_sampler import SubgraphSampler
from ..utils import unique_and_compact_node_pairs
from .sampled_subgraph_impl import SampledSubgraphImpl


class NeighborSampler(SubgraphSampler):
    """
    The neighbor sampler is responsible for sampling a subgraph from the given
    data. It returns an induced subgraph along with compacted information. In
    the context of a node classification task, the neighbor sampler directly
    utilizes the nodes provided as seed nodes. In link prediction scenarios,
    however, an extra pre-processing step is needed: the unique nodes are
    gathered from the given node pairs, encompassing both positive and
    negative pairs, and these nodes are then employed as the seed nodes for
    subsequent steps.
    """

    def __init__(
        self,
        datapipe,
        graph,
        fanouts,
        replace=False,
        prob_name=None,
    ):
        """
        Initialization for a link neighbor subgraph sampler.

        Parameters
        ----------
        datapipe : DataPipe
            The datapipe.
        graph : CSCSamplingGraph
            The graph on which to perform subgraph sampling.
        fanouts : list[torch.Tensor]
            The number of edges to be sampled for each node, with or without
            considering edge types. The length of this parameter implicitly
            determines the number of sampling layers.
        replace : bool
            Boolean indicating whether the sampling is performed with or
            without replacement. If True, a value can be selected multiple
            times; otherwise, each value can be selected only once.
        prob_name : str, optional
            The name of an edge attribute used as the sampling weight for
            each node. This attribute tensor should contain (unnormalized)
            probabilities corresponding to each neighboring edge of a node.
            It must be a 1D floating-point or boolean tensor, with the number
            of elements equal to the total number of edges.

        Examples
        --------
        >>> import torch
        >>> import dgl.graphbolt as gb
        >>> from torchdata.datapipes.iter import Mapper
        >>> def to_link_block(data):
        ...     block = gb.LinkPredictionBlock(node_pair=data)
        ...     return block
        ...
        >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7, 8])
        >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
        >>> graph = gb.from_csc(indptr, indices)
        >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
        >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
        >>> item_set = gb.ItemSet(node_pairs)
        >>> minibatch_sampler = gb.MinibatchSampler(
        ...     item_set, batch_size=1,
        ... )
        >>> data_block_converter = Mapper(minibatch_sampler, to_link_block)
        >>> neg_sampler = gb.UniformNegativeSampler(
        ...     data_block_converter, 2, data_format, graph)
        >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
        ...     torch.LongTensor([15])]
        >>> subgraph_sampler = gb.NeighborSampler(
        ...     neg_sampler, graph, fanouts)
        >>> for data in subgraph_sampler:
        ...     print(data.compacted_node_pair)
        ...     print(len(data.sampled_subgraphs))
        (tensor([0, 0, 0]), tensor([1, 0, 2]))
        3
        (tensor([0, 0, 0]), tensor([1, 1, 1]))
        3
        """
        super().__init__(datapipe)
        self.fanouts = fanouts
        self.replace = replace
        self.prob_name = prob_name
        self.graph = graph

    def _sample_sub_graphs(self, seeds):
        subgraphs = []
        num_layers = len(self.fanouts)
        for hop in range(num_layers):
            subgraph = self.graph.sample_neighbors(
                seeds,
                self.fanouts[hop],
                self.replace,
                self.prob_name,
            )
            reverse_row_node_ids = seeds
            seeds, compacted_node_pairs = unique_and_compact_node_pairs(
                subgraph.node_pairs, seeds
            )
            subgraph = SampledSubgraphImpl(
                node_pairs=compacted_node_pairs,
                reverse_column_node_ids=seeds,
                reverse_row_node_ids=reverse_row_node_ids,
            )
            subgraphs.insert(0, subgraph)
        return seeds, subgraphs
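For the node-classification path that the updated dataloader tests exercise, the wiring around NeighborSampler looks roughly like the sketch below. It is illustrative only: it reuses APIs that appear elsewhere in this commit (ItemSet, MinibatchSampler, Mapper, NodeClassificationBlock, from_csc, NeighborSampler), while the toy graph and seed range are invented.

# Sketch only: a node-classification pipeline through NeighborSampler, assuming
# the GraphBolt APIs used elsewhere in this commit. The tiny graph is made up.
import torch
import dgl.graphbolt as gb
from torchdata.datapipes.iter import Mapper

def to_node_block(data):
    # Wrap a minibatch of seed node IDs into the block type _sample() dispatches on.
    return gb.NodeClassificationBlock(seed_node=data)

indptr = torch.LongTensor([0, 2, 4, 5, 6, 7, 8])
indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
graph = gb.from_csc(indptr, indices)

item_set = gb.ItemSet(torch.arange(6))               # seed nodes 0..5
minibatch_sampler = gb.MinibatchSampler(item_set, batch_size=2)
block_converter = Mapper(minibatch_sampler, to_node_block)
fanouts = [torch.LongTensor([2]) for _ in range(2)]  # two layers, fanout 2
sampler = gb.NeighborSampler(block_converter, graph, fanouts)

for data in sampler:
    # data.input_nodes holds the expanded seed set; data.sampled_subgraphs has
    # one SampledSubgraphImpl per layer.
    print(data.input_nodes, len(data.sampled_subgraphs))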
python/dgl/graphbolt/subgraph_sampler.py

(This commit replaces the previous thin wrapper, which took an `fn : callable` sampling function and was documented as equivalent to `for data in datapipe: yield sampler_func(data)`, with the `_sample` dispatch and the `_sample_sub_graphs` override point below.)

"""Subgraph samplers"""

from collections import defaultdict
from typing import Dict

from torchdata.datapipes.iter import Mapper

from .link_prediction_block import LinkPredictionBlock
from .node_classification_block import NodeClassificationBlock
from .utils import unique_and_compact


class SubgraphSampler(Mapper):
    """A subgraph sampler used to sample a subgraph from a given set of nodes
    from a larger graph."""

    def __init__(
        self,
        datapipe,
    ):
        """
        Initialization for a subgraph sampler.

        Parameters
        ----------
        datapipe : DataPipe
            The datapipe.
        """
        super().__init__(datapipe, self._sample)

    def _sample(self, data):
        if isinstance(data, LinkPredictionBlock):
            (
                seeds,
                data.compacted_node_pair,
                data.compacted_negative_head,
                data.compacted_negative_tail,
            ) = self._link_prediction_preprocess(data)
        elif isinstance(data, NodeClassificationBlock):
            seeds = data.seed_node
        else:
            raise TypeError(f"Unsupported type of data {data}.")
        data.input_nodes, data.sampled_subgraphs = self._sample_sub_graphs(
            seeds
        )
        return data

    def _link_prediction_preprocess(self, data):
        node_pair = data.node_pair
        neg_src, neg_dst = data.negative_head, data.negative_tail
        has_neg_src = neg_src is not None
        has_neg_dst = neg_dst is not None
        is_heterogeneous = isinstance(node_pair, Dict)
        if is_heterogeneous:
            # Collect nodes from all types of input.
            nodes = defaultdict(list)
            for (src_type, _, dst_type), (src, dst) in node_pair.items():
                nodes[src_type].append(src)
                nodes[dst_type].append(dst)
            if has_neg_src:
                for (src_type, _, _), src in neg_src.items():
                    nodes[src_type].append(src.view(-1))
            if has_neg_dst:
                for (_, _, dst_type), dst in neg_dst.items():
                    nodes[dst_type].append(dst.view(-1))
            # Unique and compact the collected nodes.
            seeds, compacted = unique_and_compact(nodes)
            (
                compacted_node_pair,
                compacted_negative_head,
                compacted_negative_tail,
            ) = ({}, {}, {})
            # Map back in the same order as collected.
            for etype, _ in node_pair.items():
                src_type, _, dst_type = etype
                src = compacted[src_type].pop(0)
                dst = compacted[dst_type].pop(0)
                compacted_node_pair[etype] = (src, dst)
            if has_neg_src:
                for etype, _ in neg_src.items():
                    compacted_negative_head[etype] = compacted[etype[0]].pop(0)
            if has_neg_dst:
                for etype, _ in neg_dst.items():
                    compacted_negative_tail[etype] = compacted[etype[2]].pop(0)
        else:
            # Collect nodes from all types of input.
            nodes = list(node_pair)
            if has_neg_src:
                nodes.append(neg_src.view(-1))
            if has_neg_dst:
                nodes.append(neg_dst.view(-1))
            # Unique and compact the collected nodes.
            seeds, compacted = unique_and_compact(nodes)
            # Map back in the same order as collected.
            compacted_node_pair = tuple(compacted[:2])
            compacted = compacted[2:]
            if has_neg_src:
                compacted_negative_head = compacted.pop(0)
            if has_neg_dst:
                compacted_negative_tail = compacted.pop(0)
        return (
            seeds,
            compacted_node_pair,
            compacted_negative_head if has_neg_src else None,
            compacted_negative_tail if has_neg_dst else None,
        )

    def _sample_sub_graphs(self, seeds):
        raise NotImplementedError
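The "udf" in the commit title is now a subclass hook: `_sample_sub_graphs` raises NotImplementedError here and NeighborSampler overrides it. A minimal sketch of a custom sampler built on that hook might look as follows; the OneHopSampler name is hypothetical, and `sample_neighbors` / `node_pairs` are used only in the way the test code in this commit uses them.

# Sketch only: a custom sampler plugged into the SubgraphSampler hook above.
# "OneHopSampler" is an invented name; this skips the ID compaction that
# NeighborSampler performs and assumes a homogeneous graph.
import torch
from dgl.graphbolt.subgraph_sampler import SubgraphSampler


class OneHopSampler(SubgraphSampler):
    """Samples a fixed number of neighbors for a single hop."""

    def __init__(self, datapipe, graph, fanout):
        super().__init__(datapipe)
        self.graph = graph
        self.fanout = fanout  # e.g. torch.LongTensor([2])

    def _sample_sub_graphs(self, seeds):
        # One layer of neighbor sampling; the returned pair becomes
        # data.input_nodes and data.sampled_subgraphs in _sample().
        subgraph = self.graph.sample_neighbors(seeds, self.fanout)
        input_nodes = torch.cat([seeds, subgraph.node_pairs[0]]).unique()
        return input_nodes, [subgraph]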
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py

(Replaces the previous setup, which built a DGLGraph via dgl.add_reverse_edges(dgl.rand_graph(200, 6000)) and sampled it through partial(sampler_func, graph) wrapping dgl.dataloading.NeighborSampler([2, 2]) inside dgl.graphbolt.SubgraphSampler, with a gb_test_utils.rand_csc_graph sampled by dgl.graphbolt.NeighborSampler.)

import os
import unittest
from functools import partial

import backend as F
...
@@ -5,12 +7,17 @@ import dgl
import dgl.graphbolt
import gb_test_utils
import torch
from torchdata.datapipes.iter import Mapper


def to_node_block(data):
    block = dgl.graphbolt.NodeClassificationBlock(seed_node=data)
    return block


def to_tuple(data):
    output_nodes = data.sampled_subgraphs[-1].reverse_column_node_ids
    return data.input_nodes, output_nodes, data.sampled_subgraphs


def fetch_func(features, labels, data):
...
@@ -20,23 +27,26 @@ def fetch_func(features, labels, data):
    return input_features, output_labels, adjs


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
# TODO (peizhou): Will enable windows test once CSCSamplingraph is pickleable.
def test_DataLoader():
    N = 40
    B = 4
    itemset = dgl.graphbolt.ItemSet(torch.arange(N))
    graph = gb_test_utils.rand_csc_graph(200, 0.15)
    features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
    labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))

    minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
    block_converter = Mapper(minibatch_sampler, to_node_block)
    subgraph_sampler = dgl.graphbolt.NeighborSampler(
        block_converter,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    tuple_converter = Mapper(subgraph_sampler, to_tuple)
    feature_fetcher = dgl.graphbolt.FeatureFetcher(
        tuple_converter,
        partial(fetch_func, features, labels),
    )
    device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
    ...
tests/python/pytorch/graphbolt/test_single_process_dataloader.py

(Replaces the previous local sampler_func, which called graph.sample_neighbors directly and was fed to dgl.graphbolt.SubgraphSampler(minibatch_sampler, sampler_func), with the Mapper → NeighborSampler → Mapper(to_tuple) wiring used by the multi-process test.)

...
@@ -3,6 +3,17 @@ import dgl
import dgl.graphbolt
import gb_test_utils
import torch
from torchdata.datapipes.iter import Mapper


def to_node_block(data):
    block = dgl.graphbolt.NodeClassificationBlock(seed_node=data)
    return block


def to_tuple(data):
    output_nodes = data.sampled_subgraphs[-1].reverse_column_node_ids
    return data.input_nodes, output_nodes, data.sampled_subgraphs


def test_DataLoader():
...
@@ -13,19 +24,6 @@ def test_DataLoader():
    features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
    labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))

    def fetch_func(data):
        input_nodes, output_nodes, adjs = data
        input_features = features.read(input_nodes)
...
@@ -33,11 +31,14 @@ def test_DataLoader():
        return input_features, output_labels, adjs

    minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
    block_converter = Mapper(minibatch_sampler, to_node_block)
    subgraph_sampler = dgl.graphbolt.NeighborSampler(
        block_converter,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    tuple_converter = Mapper(subgraph_sampler, to_tuple)
    feature_fetcher = dgl.graphbolt.FeatureFetcher(tuple_converter, fetch_func)
    device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
    dataloader = dgl.graphbolt.SingleProcessDataLoader(device_transferrer)
    ...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py

(Replaces the previous parametrized test_SubgraphSampler, which combined get_graphbolt_sampler_func / get_dgl_sampler_func with get_graphbolt_minibatch_dp / get_torchdata_minibatch_dp, with explicit node, link, negative-sampling, and heterogeneous tests against NeighborSampler.)

import dgl.graphbolt as gb
import gb_test_utils
import pytest
import torch
import torchdata.datapipes as dp
from torchdata.datapipes.iter import Mapper


def to_node_block(data):
    block = gb.NodeClassificationBlock(seed_node=data)
    return block


def test_SubgraphSampler_Node():
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(10))
    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    data_block_converter = Mapper(minibatch_dp, to_node_block)
    sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
    assert len(list(sampler_dp)) == 5


def to_link_block(data):
    block = gb.LinkPredictionBlock(node_pair=data)
    return block


def test_SubgraphSampler_Link():
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(
        (
            torch.arange(0, 10),
            torch.arange(10, 20),
        )
    )
    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    data_block_converter = Mapper(minibatch_dp, to_link_block)
    neighbor_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
    assert len(list(neighbor_dp)) == 5


@pytest.mark.parametrize(
    "format",
    [
        gb.LinkPredictionEdgeFormat.INDEPENDENT,
        gb.LinkPredictionEdgeFormat.CONDITIONED,
        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
    ],
)
def test_SubgraphSampler_Link_With_Negative(format):
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(
        (
            torch.arange(0, 10),
            torch.arange(10, 20),
        )
    )
    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    data_block_converter = Mapper(minibatch_dp, to_link_block)
    negative_dp = gb.UniformNegativeSampler(
        data_block_converter, 1, format, graph
    )
    neighbor_dp = gb.NeighborSampler(negative_dp, graph, fanouts)
    assert len(list(neighbor_dp)) == 5


def get_hetero_graph():
    # COO graph:
    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
    # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.
    # num_nodes = 5, num_n1 = 2, num_n2 = 3
    ntypes = {"n1": 0, "n2": 1}
    etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
    metadata = gb.GraphMetadata(ntypes, etypes)
    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    node_type_offset = torch.LongTensor([0, 2, 5])
    return gb.from_csc(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        metadata=metadata,
    )


def test_SubgraphSampler_Link_Hetero():
    graph = get_hetero_graph()
    itemset = gb.ItemSetDict(
        {
            ("n1", "e1", "n2"): gb.ItemSet(
                (
                    torch.LongTensor([0, 0, 1, 1]),
                    torch.LongTensor([0, 2, 0, 1]),
                )
            ),
            ("n2", "e2", "n1"): gb.ItemSet(
                (
                    torch.LongTensor([0, 0, 1, 1, 2, 2]),
                    torch.LongTensor([0, 1, 1, 0, 0, 1]),
                )
            ),
        }
    )
    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    data_block_converter = Mapper(minibatch_dp, to_link_block)
    neighbor_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
    assert len(list(neighbor_dp)) == 5


@pytest.mark.parametrize(
    "format",
    [
        gb.LinkPredictionEdgeFormat.INDEPENDENT,
        gb.LinkPredictionEdgeFormat.CONDITIONED,
        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
    ],
)
def test_SubgraphSampler_Link_Hetero_With_Negative(format):
    graph = get_hetero_graph()
    itemset = gb.ItemSetDict(
        {
            ("n1", "e1", "n2"): gb.ItemSet(
                (
                    torch.LongTensor([0, 0, 1, 1]),
                    torch.LongTensor([0, 2, 0, 1]),
                )
            ),
            ("n2", "e2", "n1"): gb.ItemSet(
                (
                    torch.LongTensor([0, 0, 1, 1, 2, 2]),
                    torch.LongTensor([0, 1, 1, 0, 0, 1]),
                )
            ),
        }
    )
    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    data_block_converter = Mapper(minibatch_dp, to_link_block)
    negative_dp = gb.UniformNegativeSampler(
        data_block_converter, 1, format, graph
    )
    neighbor_dp = gb.NeighborSampler(negative_dp, graph, fanouts)
    assert len(list(neighbor_dp)) == 5