Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
86f739b3
"projects/vscode:/vscode.git/clone" did not exist on "4ecc9ea89d55b51c6ad66996ff0edd013ded0815"
Unverified
Commit
86f739b3
authored
Sep 04, 2023
by
peizhou001
Committed by
GitHub
Sep 04, 2023
Browse files
[Graphbolt] Change data_block to mini_batch (#6256)
parent
f281959a
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
139 additions
and
140 deletions
+139
-140
python/dgl/graphbolt/__init__.py
python/dgl/graphbolt/__init__.py
+1
-1
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+12
-10
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+6
-26
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+32
-29
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+20
-15
tests/python/pytorch/graphbolt/gb_test_utils.py
tests/python/pytorch/graphbolt/gb_test_utils.py
+8
-3
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+20
-19
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+10
-10
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
...python/pytorch/graphbolt/test_multi_process_dataloader.py
+4
-3
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
...ython/pytorch/graphbolt/test_single_process_dataloader.py
+4
-7
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+22
-17
No files found.
python/dgl/graphbolt/__init__.py
View file @
86f739b3
...
...
@@ -6,7 +6,7 @@ import torch
from
.._ffi
import
libinfo
from
.base
import
*
from
.
data_block
import
*
from
.
minibatch
import
*
from
.data_format
import
*
from
.dataloader
import
*
from
.dataset
import
*
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
86f739b3
...
...
@@ -53,9 +53,9 @@ class NeighborSampler(SubgraphSampler):
-------
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def to_link_block(data):
...     block = gb.LinkPredictionBlock(node_pair=data)
...     return block
>>> def minibatch_link_collator(data):
...     minibatch = gb.MiniBatch(node_pair=data)
...     return minibatch
...
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
...
...
@@ -67,9 +67,10 @@ class NeighborSampler(SubgraphSampler):
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> data_block_converter = Mapper(item_sampler, to_link_block)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler(
...data_block_converter, 2, data_format, graph)
...minibatch_converter, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])]
>>> subgraph_sampler = gb.NeighborSampler(
...
...
@@ -164,9 +165,9 @@ class LayerNeighborSampler(NeighborSampler):
-------
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def to_link_block(data):
...     block = gb.LinkPredictionBlock(node_pair=data)
...     return block
>>> def minibatch_link_collator(data):
...     minibatch = gb.MiniBatch(node_pair=data)
...     return minibatch
...
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
...
...
@@ -178,9 +179,10 @@ class LayerNeighborSampler(NeighborSampler):
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> data_block_converter = Mapper(item_sampler, to_link_block)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler(
...data_block_converter, 2, data_format, graph)
...minibatch_converter, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])]
>>> subgraph_sampler = gb.LayerNeighborSampler(
...
...
python/dgl/graphbolt/
data_block
.py
→
python/dgl/graphbolt/
minibatch
.py
View file @
86f739b3
...
...
@@ -7,11 +7,11 @@ import torch
from
.sampled_subgraph
import
SampledSubgraph
__all__
=
[
"
DataBlock"
,
"NodeClassificationBlock"
,
"LinkPredictionBlock
"
]
__all__
=
[
"
MiniBatch
"
]
@
dataclass
class
DataBlock
:
class
MiniBatch
:
r
"""A composite data class for data structure in the graphbolt. It is
designed to facilitate the exchange of data among different components
involved in processing data. The purpose of this class is to unify the
...
...
@@ -52,12 +52,6 @@ class DataBlock:
value should be corresponding heterogeneous node id.
"""
@
dataclass
class
NodeClassificationBlock
(
DataBlock
):
r
"""A subclass of 'UnifiedDataStruct', specialized for handling node level
tasks."""
seed_node
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
Representation of seed nodes used for sampling in the graph.
...
...
@@ -69,17 +63,12 @@ class NodeClassificationBlock(DataBlock):
label
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
Labels associated with seed nodes in the graph.
- If `label` is a tensor: It indicates the graph is homogeneous.
- If `label` is a dictionary: The keys should be node type and the
value should be corresponding node labels to given 'seed_node'.
- If `label` is a tensor: It indicates the graph is homogeneous. The value
should be corresponding labels to given 'seed_node' or 'node_pair'.
- If `label` is a dictionary: The keys should be node or edge type and the
value should be corresponding labels to given 'seed_node' or 'node_pair'.
"""
@
dataclass
class
LinkPredictionBlock
(
DataBlock
):
r
"""A subclass of 'UnifiedDataStruct', specialized for handling edge level
tasks."""
node_pair
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
...
...
@@ -93,15 +82,6 @@ class LinkPredictionBlock(DataBlock):
type.
"""
label
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
Labels associated with the link prediction task.
- If `label` is a tensor: It indicates a homogeneous graph. The values are
the edge labels corresponding to the given 'node_pair'.
- If `label` is a dictionary: The keys should be edge type, and the value
should correspond to given 'node_pair'.
"""
negative_head
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""
Representation of negative samples for the head nodes in the link
...
...
python/dgl/graphbolt/negative_sampler.py
View file @
86f739b3
...
...
@@ -30,49 +30,50 @@ class NegativeSampler(Mapper):
negative_ratio : int
The proportion of negative samples to positive samples.
output_format : LinkPredictionEdgeFormat
Determines the edge format of the output data.
Determines the edge format of the output minibatch.
"""
super
().
__init__
(
datapipe
,
self
.
_sample
)
assert
negative_ratio
>
0
,
"Negative_ratio should be positive Integer."
self
.
negative_ratio
=
negative_ratio
self
.
output_format
=
output_format
def
_sample
(
self
,
data
):
def
_sample
(
self
,
minibatch
):
"""
Generate a mix of positive and negative samples.
Parameters
----------
data : LinkPredictionBlock
An instance of 'LinkPredictionBlock' class requires the 'node_pair'
field. This function is responsible for generating negative edges
minibatch : MiniBatch
An instance of 'MiniBatch' class requires the 'node_pair'
field. This function is responsible for generating negative edges
corresponding to the positive edges defined by the 'node_pair'. In
cases where negative edges already exist, this function will
overwrite them.
Returns
-------
LinkPredictionBlock
An instance of 'LinkPredictionBlock' encompasses both positive and
negative samples.
MiniBatch
An instance of 'MiniBatch' encompasses both positive and
negative samples.
"""
node_pairs
=
data
.
node_pair
node_pairs
=
minibatch
.
node_pair
assert
node_pairs
is
not
None
if
isinstance
(
node_pairs
,
Mapping
):
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
data
.
label
=
{}
minibatch
.
label
=
{}
else
:
data
.
negative_head
,
data
.
negative_tail
=
{},
{}
minibatch
.
negative_head
,
minibatch
.
negative_tail
=
{},
{}
for
etype
,
pos_pairs
in
node_pairs
.
items
():
self
.
_collate
(
data
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
),
etype
minibatch
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
),
etype
)
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
:
data
.
negative_tail
=
None
minibatch
.
negative_tail
=
None
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
:
data
.
negative_head
=
None
minibatch
.
negative_head
=
None
else
:
self
.
_collate
(
data
,
self
.
_sample_with_etype
(
node_pairs
))
return
data
self
.
_collate
(
minibatch
,
self
.
_sample_with_etype
(
node_pairs
))
return
minibatch
def
_sample_with_etype
(
self
,
node_pairs
,
etype
=
None
):
"""Generate negative pairs for a given etype from positive pairs
...
...
@@ -94,13 +95,13 @@ class NegativeSampler(Mapper):
"""
raise
NotImplementedError
def
_collate
(
self
,
data
,
neg_pairs
,
etype
=
None
):
"""Collates positive and negative samples into
data
.
def
_collate
(
self
,
minibatch
,
neg_pairs
,
etype
=
None
):
"""Collates positive and negative samples into
minibatch
.
Parameters
----------
data : LinkPredictionBlock
The input data, which contains positive node pairs, will be filled
minibatch : MiniBatch
The input minibatch, which contains positive node pairs, will be filled
with negative information in this function.
neg_pairs : Tuple[Tensor, Tensor]
A tuple of tensors represents source-destination node pairs of
...
...
@@ -110,7 +111,9 @@ class NegativeSampler(Mapper):
Canonical edge type.
"""
pos_src
,
pos_dst
=
(
data
.
node_pair
[
etype
]
if
etype
is
not
None
else
data
.
node_pair
minibatch
.
node_pair
[
etype
]
if
etype
is
not
None
else
minibatch
.
node_pair
)
neg_src
,
neg_dst
=
neg_pairs
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
...
...
@@ -120,11 +123,11 @@ class NegativeSampler(Mapper):
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
label
=
torch
.
cat
([
pos_label
,
neg_label
])
if
etype
is
not
None
:
data
.
node_pair
[
etype
]
=
(
src
,
dst
)
data
.
label
[
etype
]
=
label
minibatch
.
node_pair
[
etype
]
=
(
src
,
dst
)
minibatch
.
label
[
etype
]
=
label
else
:
data
.
node_pair
=
(
src
,
dst
)
data
.
label
=
label
minibatch
.
node_pair
=
(
src
,
dst
)
minibatch
.
label
=
label
else
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
CONDITIONED
:
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
...
...
@@ -144,8 +147,8 @@ class NegativeSampler(Mapper):
f
"Unsupported output format
{
self
.
output_format
}
."
)
if
etype
is
not
None
:
data
.
negative_head
[
etype
]
=
neg_src
data
.
negative_tail
[
etype
]
=
neg_dst
minibatch
.
negative_head
[
etype
]
=
neg_src
minibatch
.
negative_tail
[
etype
]
=
neg_dst
else
:
data
.
negative_head
=
neg_src
data
.
negative_tail
=
neg_dst
minibatch
.
negative_head
=
neg_src
minibatch
.
negative_tail
=
neg_dst
python/dgl/graphbolt/subgraph_sampler.py
View file @
86f739b3
...
...
@@ -6,7 +6,6 @@ from typing import Dict
from
torchdata.datapipes.iter
import
Mapper
from
.base
import
etype_str_to_tuple
from
.data_block
import
LinkPredictionBlock
,
NodeClassificationBlock
from
.utils
import
unique_and_compact
...
...
@@ -28,24 +27,30 @@ class SubgraphSampler(Mapper):
"""
super
().
__init__
(
datapipe
,
self
.
_sample
)
def
_sample
(
self
,
data
):
if
isinstance
(
data
,
LinkPredictionBlock
)
:
def
_sample
(
self
,
minibatch
):
if
minibatch
.
node_pair
is
not
None
:
(
seeds
,
data
.
compacted_node_pair
,
data
.
compacted_negative_head
,
data
.
compacted_negative_tail
,
)
=
self
.
_
link_prediction
_preprocess
(
data
)
elif
isinstance
(
data
,
NodeClassificationBlock
)
:
seeds
=
data
.
seed_node
minibatch
.
compacted_node_pair
,
minibatch
.
compacted_negative_head
,
minibatch
.
compacted_negative_tail
,
)
=
self
.
_
node_pair
_preprocess
(
minibatch
)
elif
minibatch
.
seed_node
is
not
None
:
seeds
=
minibatch
.
seed_node
else
:
raise
TypeError
(
f
"Unsupported type of data
{
data
}
."
)
data
.
input_nodes
,
data
.
sampled_subgraphs
=
self
.
_sample_subgraphs
(
seeds
)
return
data
raise
ValueError
(
f
"Invalid minibatch
{
minibatch
}
: Either 'node_pair' or
\
'seed_node' should have a value."
)
(
minibatch
.
input_nodes
,
minibatch
.
sampled_subgraphs
,
)
=
self
.
_sample_subgraphs
(
seeds
)
return
minibatch
def
_
link_prediction
_preprocess
(
self
,
data
):
node_pair
=
data
.
node_pair
neg_src
,
neg_dst
=
data
.
negative_head
,
data
.
negative_tail
def
_
node_pair
_preprocess
(
self
,
minibatch
):
node_pair
=
minibatch
.
node_pair
neg_src
,
neg_dst
=
minibatch
.
negative_head
,
minibatch
.
negative_tail
has_neg_src
=
neg_src
is
not
None
has_neg_dst
=
neg_dst
is
not
None
is_heterogeneous
=
isinstance
(
node_pair
,
Dict
)
...
...
tests/python/pytorch/graphbolt/gb_test_utils.py
View file @
86f739b3
...
...
@@ -8,9 +8,14 @@ import scipy.sparse as sp
import
torch
def
to_node_block
(
data
):
block
=
gb
.
NodeClassificationBlock
(
seed_node
=
data
)
return
block
def
minibatch_node_collator
(
data
):
minibatch
=
gb
.
MiniBatch
(
seed_node
=
data
)
return
minibatch
def
minibatch_link_collator
(
data
):
minibatch
=
gb
.
MiniBatch
(
node_pair
=
data
)
return
minibatch
def
rand_csc_graph
(
N
,
density
):
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
86f739b3
...
...
@@ -5,10 +5,6 @@ import torch
from
torchdata.datapipes.iter
import
Mapper
def
to_data_block
(
data
):
return
gb
.
LinkPredictionBlock
(
node_pair
=
data
)
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
def
test_NegativeSampler_Independent_Format
(
negative_ratio
):
# Construct CSCSamplingGraph.
...
...
@@ -22,10 +18,12 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
data_block_converter
=
Mapper
(
item_sampler
,
to_data_block
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
data_block
_converter
,
minibatch
_converter
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
graph
,
...
...
@@ -55,10 +53,12 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
data_block_converter
=
Mapper
(
item_sampler
,
to_data_block
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
data_block
_converter
,
minibatch
_converter
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
CONDITIONED
,
graph
,
...
...
@@ -91,10 +91,12 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
data_block_converter
=
Mapper
(
item_sampler
,
to_data_block
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
data_block
_converter
,
minibatch
_converter
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
,
graph
,
...
...
@@ -125,10 +127,12 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
data_block_converter
=
Mapper
(
item_sampler
,
to_data_block
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
data_block
_converter
,
minibatch
_converter
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
,
graph
,
...
...
@@ -166,11 +170,6 @@ def get_hetero_graph():
)
def
to_link_block
(
data
):
block
=
gb
.
LinkPredictionBlock
(
node_pair
=
data
)
return
block
@
pytest
.
mark
.
parametrize
(
"format"
,
[
...
...
@@ -200,8 +199,10 @@ def test_NegativeSampler_Hetero_Data(format):
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
data_block_converter
=
Mapper
(
item_sampler_dp
,
to_link_block
)
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
negative_dp
=
gb
.
UniformNegativeSampler
(
data_block
_converter
,
1
,
format
,
graph
minibatch
_converter
,
1
,
format
,
graph
)
assert
len
(
list
(
negative_dp
))
==
5
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
86f739b3
...
...
@@ -19,8 +19,10 @@ def test_FeatureFetcher_homo():
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
data_block_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
to_node_block
)
sampler_dp
=
gb
.
NeighborSampler
(
data_block_converter
,
graph
,
fanouts
)
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_node_collator
)
sampler_dp
=
gb
.
NeighborSampler
(
minibatch_converter
,
graph
,
fanouts
)
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
fetcher_dp
))
==
5
...
...
@@ -40,9 +42,7 @@ def test_FeatureFetcher_with_edges_homo():
reverse_edge_ids
=
torch
.
randint
(
0
,
graph
.
num_edges
,
(
10
,)),
)
)
data
=
gb
.
NodeClassificationBlock
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
data
=
gb
.
MiniBatch
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
return
data
features
=
{}
...
...
@@ -106,8 +106,10 @@ def test_FeatureFetcher_hetero():
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
data_block_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
to_node_block
)
sampler_dp
=
gb
.
NeighborSampler
(
data_block_converter
,
graph
,
fanouts
)
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_node_collator
)
sampler_dp
=
gb
.
NeighborSampler
(
minibatch_converter
,
graph
,
fanouts
)
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
{
"n1"
:
[
"a"
],
"n2"
:
[
"a"
]}
)
...
...
@@ -132,9 +134,7 @@ def test_FeatureFetcher_with_edges_hetero():
reverse_edge_ids
=
reverse_edge_ids
,
)
)
data
=
gb
.
NodeClassificationBlock
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
data
=
gb
.
MiniBatch
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
return
data
features
=
{}
...
...
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
View file @
86f739b3
import
os
import
unittest
from
functools
import
partial
import
backend
as
F
...
...
@@ -23,9 +22,11 @@ def test_DataLoader():
feature_store
=
dgl
.
graphbolt
.
BasicFeatureStore
(
features
)
item_sampler
=
dgl
.
graphbolt
.
ItemSampler
(
itemset
,
batch_size
=
B
)
block_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
to_node_block
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_node_collator
)
subgraph_sampler
=
dgl
.
graphbolt
.
NeighborSampler
(
block
_converter
,
minibatch
_converter
,
graph
,
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
2
)],
)
...
...
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
View file @
86f739b3
...
...
@@ -7,11 +7,6 @@ import torch
from
torchdata.datapipes.iter
import
Mapper
def
to_node_block
(
data
):
block
=
dgl
.
graphbolt
.
NodeClassificationBlock
(
seed_node
=
data
)
return
block
def
test_DataLoader
():
N
=
32
B
=
4
...
...
@@ -25,9 +20,11 @@ def test_DataLoader():
feature_store
=
dgl
.
graphbolt
.
BasicFeatureStore
(
features
)
item_sampler
=
dgl
.
graphbolt
.
ItemSampler
(
itemset
,
batch_size
=
B
)
block_converter
=
Mapper
(
item_sampler
,
to_node_block
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_node_collator
)
subgraph_sampler
=
dgl
.
graphbolt
.
NeighborSampler
(
block
_converter
,
minibatch
_converter
,
graph
,
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
2
)],
)
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
86f739b3
...
...
@@ -6,11 +6,6 @@ import torchdata.datapipes as dp
from
torchdata.datapipes.iter
import
Mapper
def
to_node_block
(
data
):
block
=
gb
.
NodeClassificationBlock
(
seed_node
=
data
)
return
block
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Node
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
...
...
@@ -18,14 +13,16 @@ def test_SubgraphSampler_Node(labor):
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
data_block_converter
=
Mapper
(
item_sampler_dp
,
to_node_block
)
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_node_collator
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
sampler_dp
=
Sampler
(
data_block
_converter
,
graph
,
fanouts
)
sampler_dp
=
Sampler
(
minibatch
_converter
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
5
def
to_link_b
lock
(
data
):
block
=
gb
.
L
in
kPredictionBlock
(
node_pair
=
data
)
def
to_link_b
atch
(
data
):
block
=
gb
.
M
in
iBatch
(
node_pair
=
data
)
return
block
...
...
@@ -41,9 +38,11 @@ def test_SubgraphSampler_Link(labor):
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
data_block_converter
=
Mapper
(
item_sampler_dp
,
to_link_block
)
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
data_block
_converter
,
graph
,
fanouts
)
neighbor_dp
=
Sampler
(
minibatch
_converter
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
...
...
@@ -68,9 +67,11 @@ def test_SubgraphSampler_Link_With_Negative(format, labor):
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
data_block_converter
=
Mapper
(
item_sampler_dp
,
to_link_block
)
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
negative_dp
=
gb
.
UniformNegativeSampler
(
data_block
_converter
,
1
,
format
,
graph
minibatch
_converter
,
1
,
format
,
graph
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
negative_dp
,
graph
,
fanouts
)
...
...
@@ -122,9 +123,11 @@ def test_SubgraphSampler_Link_Hetero(labor):
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
data_block_converter
=
Mapper
(
item_sampler_dp
,
to_link_block
)
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
data_block
_converter
,
graph
,
fanouts
)
neighbor_dp
=
Sampler
(
minibatch
_converter
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
...
...
@@ -160,9 +163,11 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
data_block_converter
=
Mapper
(
item_sampler_dp
,
to_link_block
)
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
negative_dp
=
gb
.
UniformNegativeSampler
(
data_block
_converter
,
1
,
format
,
graph
minibatch
_converter
,
1
,
format
,
graph
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
negative_dp
,
graph
,
fanouts
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment