Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
79a95477
Unverified
Commit
79a95477
authored
Sep 07, 2023
by
Rhett Ying
Committed by
GitHub
Sep 07, 2023
Browse files
[GraphBolt] clean up minibatch collator in testcases and docstring (#6294)
parent
b4c351b4
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
62 additions
and
164 deletions
+62
-164
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+6
-18
python/dgl/graphbolt/impl/uniform_negative_sampler.py
python/dgl/graphbolt/impl/uniform_negative_sampler.py
+8
-8
tests/python/pytorch/graphbolt/gb_test_utils.py
tests/python/pytorch/graphbolt/gb_test_utils.py
+0
-10
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+16
-47
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+7
-13
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
...python/pytorch/graphbolt/test_multi_process_dataloader.py
+2
-5
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
...ython/pytorch/graphbolt/test_single_process_dataloader.py
+2
-5
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+21
-58
No files found.
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
79a95477
...
...
@@ -53,24 +53,18 @@ class NeighborSampler(SubgraphSampler):
-------
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs =
(
torch.
t
ensor([0, 1]
)
,
torch.tensor(
[1, 2]
)
)
>>> item_set = gb.ItemSet(node_pairs)
>>> node_pairs = torch.
LongT
ensor([
[
0, 1], [1, 2]
]
)
>>> item_set = gb.ItemSet(node_pairs
, names="node_pairs"
)
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler(
...
minibatch_convert
er, 2, data_format, graph)
...
item_sampl
er, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])]
>>> subgraph_sampler = gb.NeighborSampler(
...
...
@@ -165,24 +159,18 @@ class LayerNeighborSampler(NeighborSampler):
-------
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs =
(
torch.
t
ensor([0, 1]
)
,
torch.tensor(
[1, 2]
)
)
>>> item_set = gb.ItemSet(node_pairs)
>>> node_pairs = torch.
LongT
ensor([
[
0, 1], [1, 2]
]
)
>>> item_set = gb.ItemSet(node_pairs
, names="node_pairs"
)
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler(
...
minibatch_convert
er, 2, data_format, graph)
...
item_sampl
er, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])]
>>> subgraph_sampler = gb.LayerNeighborSampler(
...
...
python/dgl/graphbolt/impl/uniform_negative_sampler.py
View file @
79a95477
...
...
@@ -42,15 +42,15 @@ class UniformNegativeSampler(NegativeSampler):
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs =
(
torch.tensor([0, 1]
)
,
torch.tensor(
[1, 2]
)
)
>>> item_set = gb.ItemSet(node_pairs)
>>> node_pairs = torch.tensor([
[
0, 1], [1, 2]
]
)
>>> item_set = gb.ItemSet(node_pairs
, names="node_pairs"
)
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data)
... print(data
.node_pairs, data.negative_dsts
)
...
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
(tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
...
...
@@ -60,18 +60,18 @@ class UniformNegativeSampler(NegativeSampler):
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
>>> node_pairs =
(
torch.tensor([0, 1]
)
,
torch.tensor(
[1, 2]
)
)
>>> item_set = gb.ItemSet(node_pairs)
>>> node_pairs = torch.tensor([
[
0, 1], [1, 2]
]
)
>>> item_set = gb.ItemSet(node_pairs
, names="node_pairs"
)
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data)
... print(data
.node_pairs, data.negative_dsts
)
...
(tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[
2
, 1]]))
(tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[
1
,
2
]]))
(tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[
0
, 1]]))
(tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[
0
,
1
]]))
"""
super
().
__init__
(
datapipe
,
negative_ratio
,
output_format
)
self
.
graph
=
graph
...
...
tests/python/pytorch/graphbolt/gb_test_utils.py
View file @
79a95477
...
...
@@ -8,16 +8,6 @@ import scipy.sparse as sp
import
torch
def
minibatch_node_collator
(
data
):
minibatch
=
gb
.
MiniBatch
(
seed_nodes
=
data
)
return
minibatch
def
minibatch_link_collator
(
data
):
minibatch
=
gb
.
MiniBatch
(
node_pairs
=
data
)
return
minibatch
def
rand_csc_graph
(
N
,
density
):
adj
=
sp
.
random
(
N
,
N
,
density
)
adj
=
adj
+
adj
.
T
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
79a95477
...
...
@@ -11,19 +11,13 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
num_seeds
),
torch
.
arange
(
num_seeds
,
num_seeds
*
2
),
)
torch
.
arange
(
0
,
num_seeds
*
2
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_convert
er
,
item_sampl
er
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
graph
,
...
...
@@ -46,19 +40,13 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
num_seeds
),
torch
.
arange
(
num_seeds
,
num_seeds
*
2
),
)
torch
.
arange
(
0
,
num_seeds
*
2
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_convert
er
,
item_sampl
er
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
CONDITIONED
,
graph
,
...
...
@@ -84,19 +72,13 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
num_seeds
),
torch
.
arange
(
num_seeds
,
num_seeds
*
2
),
)
torch
.
arange
(
0
,
num_seeds
*
2
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_convert
er
,
item_sampl
er
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
,
graph
,
...
...
@@ -120,19 +102,13 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
num_seeds
),
torch
.
arange
(
num_seeds
,
num_seeds
*
2
),
)
torch
.
arange
(
0
,
num_seeds
*
2
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_convert
er
,
item_sampl
er
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
,
graph
,
...
...
@@ -184,25 +160,18 @@ def test_NegativeSampler_Hetero_Data(format):
itemset
=
gb
.
ItemSetDict
(
{
"n1:e1:n2"
:
gb
.
ItemSet
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
)
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
),
"n2:e2:n1"
:
gb
.
ItemSet
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
)
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
),
}
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
negative_dp
=
gb
.
UniformNegativeSampler
(
minibatch_converter
,
1
,
format
,
graph
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
1
,
format
,
graph
)
for
neg
in
negative_dp
:
print
(
neg
)
assert
len
(
list
(
negative_dp
))
==
5
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
79a95477
...
...
@@ -15,14 +15,11 @@ def test_FeatureFetcher_homo():
features
[
keys
[
1
]]
=
gb
.
TorchBasedFeature
(
b
)
feature_store
=
gb
.
BasicFeatureStore
(
features
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
))
item_sampler
_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
)
,
names
=
"seed_nodes"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_node_collator
)
sampler_dp
=
gb
.
NeighborSampler
(
minibatch_converter
,
graph
,
fanouts
)
sampler_dp
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
fetcher_dp
))
==
5
...
...
@@ -99,17 +96,14 @@ def test_FeatureFetcher_hetero():
itemset
=
gb
.
ItemSetDict
(
{
"n1"
:
gb
.
ItemSet
(
torch
.
LongTensor
([
0
,
1
])),
"n2"
:
gb
.
ItemSet
(
torch
.
LongTensor
([
0
,
1
,
2
])),
"n1"
:
gb
.
ItemSet
(
torch
.
LongTensor
([
0
,
1
])
,
names
=
"seed_nodes"
),
"n2"
:
gb
.
ItemSet
(
torch
.
LongTensor
([
0
,
1
,
2
])
,
names
=
"seed_nodes"
),
}
)
item_sampler
_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_node_collator
)
sampler_dp
=
gb
.
NeighborSampler
(
minibatch_converter
,
graph
,
fanouts
)
sampler_dp
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
fetcher_dp
=
gb
.
FeatureFetcher
(
sampler_dp
,
feature_store
,
{
"n1"
:
[
"a"
],
"n2"
:
[
"a"
]}
)
...
...
tests/python/pytorch/graphbolt/test_multi_process_dataloader.py
View file @
79a95477
...
...
@@ -13,7 +13,7 @@ from torchdata.datapipes.iter import Mapper
def
test_DataLoader
():
N
=
40
B
=
4
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
))
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
)
,
names
=
"seed_nodes"
)
graph
=
gb_test_utils
.
rand_csc_graph
(
200
,
0.15
)
features
=
{}
keys
=
[(
"node"
,
None
,
"a"
),
(
"node"
,
None
,
"b"
)]
...
...
@@ -22,11 +22,8 @@ def test_DataLoader():
feature_store
=
dgl
.
graphbolt
.
BasicFeatureStore
(
features
)
item_sampler
=
dgl
.
graphbolt
.
ItemSampler
(
itemset
,
batch_size
=
B
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_node_collator
)
subgraph_sampler
=
dgl
.
graphbolt
.
NeighborSampler
(
minibatch_convert
er
,
item_sampl
er
,
graph
,
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
2
)],
)
...
...
tests/python/pytorch/graphbolt/test_single_process_dataloader.py
View file @
79a95477
...
...
@@ -10,7 +10,7 @@ from torchdata.datapipes.iter import Mapper
def
test_DataLoader
():
N
=
32
B
=
4
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
))
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
)
,
names
=
"seed_nodes"
)
graph
=
gb_test_utils
.
rand_csc_graph
(
200
,
0.15
)
features
=
{}
...
...
@@ -20,11 +20,8 @@ def test_DataLoader():
feature_store
=
dgl
.
graphbolt
.
BasicFeatureStore
(
features
)
item_sampler
=
dgl
.
graphbolt
.
ItemSampler
(
itemset
,
batch_size
=
B
)
minibatch_converter
=
Mapper
(
item_sampler
,
gb_test_utils
.
minibatch_node_collator
)
subgraph_sampler
=
dgl
.
graphbolt
.
NeighborSampler
(
minibatch_convert
er
,
item_sampl
er
,
graph
,
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
2
)],
)
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
79a95477
...
...
@@ -9,15 +9,12 @@ from torchdata.datapipes.iter import Mapper
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Node
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
))
item_sampler
_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
)
,
names
=
"seed_nodes"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_node_collator
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
sampler_dp
=
Sampler
(
minibatch_convert
er
,
graph
,
fanouts
)
sampler_dp
=
Sampler
(
item_sampl
er
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
5
...
...
@@ -29,20 +26,12 @@ def to_link_batch(data):
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
itemset
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
10
),
torch
.
arange
(
10
,
20
),
)
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
minibatch_convert
er
,
graph
,
fanouts
)
neighbor_dp
=
Sampler
(
item_sampl
er
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
...
...
@@ -58,21 +47,11 @@ def test_SubgraphSampler_Link(labor):
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link_With_Negative
(
format
,
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
itemset
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
10
),
torch
.
arange
(
10
,
20
),
)
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
negative_dp
=
gb
.
UniformNegativeSampler
(
minibatch_converter
,
1
,
format
,
graph
)
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
1
,
format
,
graph
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
negative_dp
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
...
...
@@ -106,28 +85,21 @@ def test_SubgraphSampler_Link_Hetero(labor):
itemset
=
gb
.
ItemSetDict
(
{
"n1:e1:n2"
:
gb
.
ItemSet
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
)
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
),
"n2:e2:n1"
:
gb
.
ItemSet
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
)
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
),
}
)
item_sampler
_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
minibatch_convert
er
,
graph
,
fanouts
)
neighbor_dp
=
Sampler
(
item_sampl
er
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
...
...
@@ -146,29 +118,20 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
itemset
=
gb
.
ItemSetDict
(
{
"n1:e1:n2"
:
gb
.
ItemSet
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
)
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
),
"n2:e2:n1"
:
gb
.
ItemSet
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
)
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
),
}
)
item_sampler
_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
minibatch_converter
=
Mapper
(
item_sampler_dp
,
gb_test_utils
.
minibatch_link_collator
)
negative_dp
=
gb
.
UniformNegativeSampler
(
minibatch_converter
,
1
,
format
,
graph
)
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
1
,
format
,
graph
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
negative_dp
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment