Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
1ff5f09f
"tests/vscode:/vscode.git/clone" did not exist on "e2dfe0d17be21ce08e759dab4712bc37debd44c4"
Unverified
Commit
1ff5f09f
authored
Sep 08, 2023
by
peizhou001
Committed by
GitHub
Sep 08, 2023
Browse files
[Graphbolt] Remove link prediction format (#6298)
parent
09c8e8d9
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
33 additions
and
258 deletions
+33
-258
python/dgl/graphbolt/__init__.py
python/dgl/graphbolt/__init__.py
+0
-1
python/dgl/graphbolt/data_format.py
python/dgl/graphbolt/data_format.py
+0
-40
python/dgl/graphbolt/impl/uniform_negative_sampler.py
python/dgl/graphbolt/impl/uniform_negative_sampler.py
+6
-27
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+10
-55
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+6
-0
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+7
-112
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+4
-23
No files found.
python/dgl/graphbolt/__init__.py
View file @
1ff5f09f
...
...
@@ -7,7 +7,6 @@ import torch
from
.._ffi
import
libinfo
from
.base
import
*
from
.minibatch
import
*
from
.data_format
import
*
from
.dataloader
import
*
from
.dataset
import
*
from
.feature_fetcher
import
*
...
...
python/dgl/graphbolt/data_format.py
deleted
100644 → 0
View file @
09c8e8d9
"""Data format enums for graphbolt."""
from
enum
import
Enum
__all__
=
[
"LinkPredictionEdgeFormat"
]
class
LinkPredictionEdgeFormat
(
Enum
):
"""
An Enum class representing the formats of positive and negative edges used
in link prediction:
Attributes:
INDEPENDENT: Represents the 'independent' format where data is structured
as triples `(u, v, label)` indicating the source and destination nodes of
an edge, with a label (0 or 1) denoting it as negative or positive.
CONDITIONED: Represents the 'conditioned' format where data is structured
as quadruples `(u, v, neg_u, neg_v)` indicating the source and destination
nodes of positive and negative edges. And 'u' with 'v' are 1D tensors with
the same shape, while 'neg_u' and 'neg_v' are 2D tensors with the same
shape.
HEAD_CONDITIONED: Represents the 'head conditioned' format where data is
structured as triples `(u, v, neg_u)`, where '(u, v)' signifies the
source and destination nodes of positive edges, while each node in
'neg_u' collaborates with 'v' to create negative edges. And 'u' and 'v' are
1D tensors with the same shape, while 'neg_u' is a 2D tensor.
TAIL_CONDITIONED: Represents the 'tail conditioned' format where data is
structured as triples `(u, v, neg_v)`, where '(u, v)' signifies the
source and destination nodes of positive edges, while 'u' collaborates
with each node in 'neg_v' to create negative edges. And 'u' and 'v' are
1D tensors with the same shape, while 'neg_v' is a 2D tensor.
"""
INDEPENDENT
=
"independent"
CONDITIONED
=
"conditioned"
HEAD_CONDITIONED
=
"head_conditioned"
TAIL_CONDITIONED
=
"tail_conditioned"
python/dgl/graphbolt/impl/uniform_negative_sampler.py
View file @
1ff5f09f
...
...
@@ -21,7 +21,6 @@ class UniformNegativeSampler(NegativeSampler):
self
,
datapipe
,
negative_ratio
,
output_format
,
graph
,
):
"""
...
...
@@ -33,8 +32,6 @@ class UniformNegativeSampler(NegativeSampler):
The datapipe.
negative_ratio : int
The proportion of negative samples to positive samples.
output_format : LinkPredictionEdgeFormat
Determines the format of the output data.
graph : CSCSamplingGraph
The graph on which to perform negative sampling.
...
...
@@ -44,39 +41,21 @@ class UniformNegativeSampler(NegativeSampler):
>>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> output_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = torch.tensor([[0, 1], [1, 2]])
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data.node_pairs, data.negative_dsts)
...item_sampler, 2, graph)
>>> for minibatch in neg_sampler:
... print(minibatch.negative_srcs)
... print(minibatch.negative_dsts)
...
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
(tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> output_format = gb.LinkPredictionEdgeFormat.CONDITIONED
>>> node_pairs = torch.tensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, output_format, graph)
>>> for data in neg_sampler:
... print(data.node_pairs, data.negative_dsts)
...
(tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[0, 1]]))
(tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[0, 1]]))
"""
super
().
__init__
(
datapipe
,
negative_ratio
,
output_format
)
super
().
__init__
(
datapipe
,
negative_ratio
)
self
.
graph
=
graph
def
_sample_with_etype
(
self
,
node_pairs
,
etype
=
None
):
...
...
python/dgl/graphbolt/negative_sampler.py
View file @
1ff5f09f
...
...
@@ -2,12 +2,9 @@
from
_collections_abc
import
Mapping
import
torch
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
.data_format
import
LinkPredictionEdgeFormat
@
functional_datapipe
(
"sample_negative"
)
class
NegativeSampler
(
Mapper
):
...
...
@@ -20,7 +17,6 @@ class NegativeSampler(Mapper):
self
,
datapipe
,
negative_ratio
,
output_format
,
):
"""
Initlization for a negative sampler.
...
...
@@ -31,13 +27,10 @@ class NegativeSampler(Mapper):
The datapipe.
negative_ratio : int
The proportion of negative samples to positive samples.
output_format : LinkPredictionEdgeFormat
Determines the edge format of the output minibatch.
"""
super
().
__init__
(
datapipe
,
self
.
_sample
)
assert
negative_ratio
>
0
,
"Negative_ratio should be positive Integer."
self
.
negative_ratio
=
negative_ratio
self
.
output_format
=
output_format
def
_sample
(
self
,
minibatch
):
"""
...
...
@@ -61,18 +54,11 @@ class NegativeSampler(Mapper):
node_pairs
=
minibatch
.
node_pairs
assert
node_pairs
is
not
None
if
isinstance
(
node_pairs
,
Mapping
):
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
minibatch
.
labels
=
{}
else
:
minibatch
.
negative_srcs
,
minibatch
.
negative_dsts
=
{},
{}
for
etype
,
pos_pairs
in
node_pairs
.
items
():
self
.
_collate
(
minibatch
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
),
etype
)
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
:
minibatch
.
negative_dsts
=
None
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
:
minibatch
.
negative_srcs
=
None
else
:
self
.
_collate
(
minibatch
,
self
.
_sample_with_etype
(
node_pairs
))
return
minibatch
...
...
@@ -112,42 +98,11 @@ class NegativeSampler(Mapper):
etype : str
Canonical edge type.
"""
pos_src
,
pos_dst
=
(
minibatch
.
node_pairs
[
etype
]
if
etype
is
not
None
else
minibatch
.
node_pairs
)
neg_src
,
neg_dst
=
neg_pairs
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
pos_labels
=
torch
.
ones_like
(
pos_src
)
neg_labels
=
torch
.
zeros_like
(
neg_src
)
src
=
torch
.
cat
([
pos_src
,
neg_src
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
labels
=
torch
.
cat
([
pos_labels
,
neg_labels
])
if
etype
is
not
None
:
minibatch
.
node_pairs
[
etype
]
=
(
src
,
dst
)
minibatch
.
labels
[
etype
]
=
labels
else
:
minibatch
.
node_pairs
=
(
src
,
dst
)
minibatch
.
labels
=
labels
else
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
CONDITIONED
:
if
neg_src
is
not
None
:
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
if
neg_dst
is
not
None
:
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
elif
(
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
):
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
neg_dst
=
None
elif
(
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
):
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
neg_src
=
None
else
:
raise
TypeError
(
f
"Unsupported output format
{
self
.
output_format
}
."
)
if
etype
is
not
None
:
minibatch
.
negative_srcs
[
etype
]
=
neg_src
minibatch
.
negative_dsts
[
etype
]
=
neg_dst
...
...
python/dgl/graphbolt/subgraph_sampler.py
View file @
1ff5f09f
...
...
@@ -57,6 +57,12 @@ class SubgraphSampler(Mapper):
has_neg_dst
=
neg_dst
is
not
None
is_heterogeneous
=
isinstance
(
node_pairs
,
Dict
)
if
is_heterogeneous
:
has_neg_src
=
has_neg_src
and
all
(
item
is
not
None
for
item
in
neg_src
.
values
()
)
has_neg_dst
=
has_neg_dst
and
all
(
item
is
not
None
for
item
in
neg_dst
.
values
()
)
# Collect nodes from all types of input.
nodes
=
defaultdict
(
list
)
for
etype
,
(
src
,
dst
)
in
node_pairs
.
items
():
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
1ff5f09f
...
...
@@ -2,7 +2,6 @@ import dgl.graphbolt as gb
import
gb_test_utils
import
pytest
import
torch
from
torchdata.datapipes.iter
import
Mapper
def
test_NegativeSampler_invoke
():
...
...
@@ -19,7 +18,6 @@ def test_NegativeSampler_invoke():
negative_sampler
=
gb
.
NegativeSampler
(
item_sampler
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
)
with
pytest
.
raises
(
NotImplementedError
):
next
(
iter
(
negative_sampler
))
...
...
@@ -27,7 +25,6 @@ def test_NegativeSampler_invoke():
# Invoke NegativeSampler via functional form.
negative_sampler
=
item_sampler
.
sample_negative
(
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
)
with
pytest
.
raises
(
NotImplementedError
):
next
(
iter
(
negative_sampler
))
...
...
@@ -47,20 +44,16 @@ def test_UniformNegativeSampler_invoke():
# Verify iteration over UniformNegativeSampler.
def
_verify
(
negative_sampler
):
for
data
in
negative_sampler
:
src
,
dst
=
data
.
node_pairs
labels
=
data
.
labels
# Assertation
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
labels
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
torch
.
all
(
torch
.
eq
(
labels
[:
batch_size
],
1
))
assert
torch
.
all
(
torch
.
eq
(
labels
[
batch_size
:],
0
))
assert
data
.
negative_srcs
.
size
(
0
)
==
batch_size
assert
data
.
negative_srcs
.
size
(
1
)
==
negative_ratio
assert
data
.
negative_dsts
.
size
(
0
)
==
batch_size
assert
data
.
negative_dsts
.
size
(
1
)
==
negative_ratio
# Invoke UniformNegativeSampler via class constructor.
negative_sampler
=
gb
.
UniformNegativeSampler
(
item_sampler
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
graph
,
)
_verify
(
negative_sampler
)
...
...
@@ -68,14 +61,13 @@ def test_UniformNegativeSampler_invoke():
# Invoke UniformNegativeSampler via functional form.
negative_sampler
=
item_sampler
.
sample_uniform_negative
(
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
graph
,
)
_verify
(
negative_sampler
)
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
def
test_NegativeSampler
_Independent_Format
(
negative_ratio
):
def
test_
Uniform_
NegativeSampler
(
negative_ratio
):
# Construct CSCSamplingGraph.
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
...
...
@@ -88,36 +80,6 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
negative_sampler
=
gb
.
UniformNegativeSampler
(
item_sampler
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
graph
,
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
src
,
dst
=
data
.
node_pairs
labels
=
data
.
labels
# Assertation
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
labels
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
torch
.
all
(
torch
.
eq
(
labels
[:
batch_size
],
1
))
assert
torch
.
all
(
torch
.
eq
(
labels
[
batch_size
:],
0
))
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
def
test_NegativeSampler_Conditioned_Format
(
negative_ratio
):
# Construct CSCSamplingGraph.
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
num_seeds
*
2
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
item_sampler
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
CONDITIONED
,
graph
,
)
# Perform Negative sampling.
...
...
@@ -135,64 +97,6 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
assert
torch
.
equal
(
expected_src
,
neg_src
)
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
def
test_NegativeSampler_Head_Conditioned_Format
(
negative_ratio
):
# Construct CSCSamplingGraph.
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
num_seeds
*
2
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
item_sampler
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
,
graph
,
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
pos_src
,
pos_dst
=
data
.
node_pairs
neg_src
=
data
.
negative_srcs
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
neg_src
)
==
batch_size
assert
neg_src
.
numel
()
==
batch_size
*
negative_ratio
expected_src
=
pos_src
.
repeat
(
negative_ratio
).
view
(
-
1
,
negative_ratio
)
assert
torch
.
equal
(
expected_src
,
neg_src
)
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
def
test_NegativeSampler_Tail_Conditioned_Format
(
negative_ratio
):
# Construct CSCSamplingGraph.
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
num_seeds
*
2
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
item_sampler
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
,
graph
,
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
pos_src
,
pos_dst
=
data
.
node_pairs
neg_dst
=
data
.
negative_dsts
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
neg_dst
)
==
batch_size
assert
neg_dst
.
numel
()
==
batch_size
*
negative_ratio
def
get_hetero_graph
():
# COO graph:
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
...
...
@@ -215,16 +119,7 @@ def get_hetero_graph():
)
@
pytest
.
mark
.
parametrize
(
"format"
,
[
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
gb
.
LinkPredictionEdgeFormat
.
CONDITIONED
,
gb
.
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
,
gb
.
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
,
],
)
def
test_NegativeSampler_Hetero_Data
(
format
):
def
test_NegativeSampler_Hetero_Data
():
graph
=
get_hetero_graph
()
itemset
=
gb
.
ItemSetDict
(
{
...
...
@@ -240,5 +135,5 @@ def test_NegativeSampler_Hetero_Data(format):
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
1
,
format
,
graph
)
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
1
,
graph
)
assert
len
(
list
(
negative_dp
))
==
5
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
1ff5f09f
...
...
@@ -2,7 +2,6 @@ import dgl.graphbolt as gb
import
gb_test_utils
import
pytest
import
torch
import
torchdata.datapipes
as
dp
from
torchdata.datapipes.iter
import
Mapper
...
...
@@ -71,23 +70,14 @@ def test_SubgraphSampler_Link(labor):
assert
len
(
list
(
neighbor_dp
))
==
5
@
pytest
.
mark
.
parametrize
(
"format"
,
[
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
gb
.
LinkPredictionEdgeFormat
.
CONDITIONED
,
gb
.
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
,
gb
.
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
,
],
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link_With_Negative
(
format
,
labor
):
def
test_SubgraphSampler_Link_With_Negative
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
1
,
format
,
graph
)
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
1
,
graph
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
negative_dp
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
...
...
@@ -139,17 +129,8 @@ def test_SubgraphSampler_Link_Hetero(labor):
assert
len
(
list
(
neighbor_dp
))
==
5
@
pytest
.
mark
.
parametrize
(
"format"
,
[
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
gb
.
LinkPredictionEdgeFormat
.
CONDITIONED
,
gb
.
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
,
gb
.
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
,
],
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link_Hetero_With_Negative
(
format
,
labor
):
def
test_SubgraphSampler_Link_Hetero_With_Negative
(
labor
):
graph
=
get_hetero_graph
()
itemset
=
gb
.
ItemSetDict
(
{
...
...
@@ -167,7 +148,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
1
,
format
,
graph
)
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
1
,
graph
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
negative_dp
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment