Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c6cdeb6b
Unverified
Commit
c6cdeb6b
authored
Aug 21, 2023
by
peizhou001
Committed by
GitHub
Aug 21, 2023
Browse files
[Graphbolt] Adapt negative sampler input (#6177)
parent
405de769
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
12 deletions
+21
-12
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+8
-8
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+13
-4
No files found.
python/dgl/graphbolt/negative_sampler.py
View file @
c6cdeb6b
...
...
@@ -6,7 +6,6 @@ import torch
from
torchdata.datapipes.iter
import
Mapper
from
.data_format
import
LinkPredictionEdgeFormat
from
.link_prediction_block
import
LinkPredictionBlock
class
NegativeSampler
(
Mapper
):
...
...
@@ -38,16 +37,18 @@ class NegativeSampler(Mapper):
self
.
negative_ratio
=
negative_ratio
self
.
output_format
=
output_format
def
_sample
(
self
,
node_pairs
):
def
_sample
(
self
,
data
):
"""
Generate a mix of positive and negative samples.
Parameters
----------
node_pairs : Tuple[Tensor] or Dict[etype, Tuple[Tensor]]
A tuple of tensors or a dictionary represents source-destination
node pairs of positive edges, where positive means the edge must
exist in the graph.
data : LinkPredictionBlock
An instance of 'LinkPredictionBlock' class requires the 'node_pair'
field. This function is responsible for generating negative edges
corresponding to the positive edges defined by the 'node_pair'. In
cases where negative edges already exist, this function will
overwrite them.
Returns
-------
...
...
@@ -55,8 +56,7 @@ class NegativeSampler(Mapper):
An instance of 'LinkPredictionBlock' encompasses both positive and
negative samples.
"""
data
=
LinkPredictionBlock
(
node_pair
=
node_pairs
)
node_pairs
=
data
.
node_pair
if
isinstance
(
node_pairs
,
Mapping
):
for
etype
,
pos_pairs
in
node_pairs
.
items
():
self
.
_collate
(
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
c6cdeb6b
...
...
@@ -2,6 +2,11 @@ import dgl.graphbolt as gb
import
gb_test_utils
import
pytest
import
torch
from
torchdata.datapipes.iter
import
Mapper
def
to_data_block
(
data
):
return
gb
.
LinkPredictionBlock
(
node_pair
=
data
)
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
...
...
@@ -17,9 +22,10 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
batch_size
=
10
minibatch_sampler
=
gb
.
MinibatchSampler
(
item_set
,
batch_size
=
batch_size
)
data_block_converter
=
Mapper
(
minibatch_sampler
,
to_data_block
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_sampl
er
,
data_block_convert
er
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
graph
,
...
...
@@ -49,9 +55,10 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
batch_size
=
10
minibatch_sampler
=
gb
.
MinibatchSampler
(
item_set
,
batch_size
=
batch_size
)
data_block_converter
=
Mapper
(
minibatch_sampler
,
to_data_block
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_sampl
er
,
data_block_convert
er
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
CONDITIONED
,
graph
,
...
...
@@ -84,9 +91,10 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
batch_size
=
10
minibatch_sampler
=
gb
.
MinibatchSampler
(
item_set
,
batch_size
=
batch_size
)
data_block_converter
=
Mapper
(
minibatch_sampler
,
to_data_block
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_sampl
er
,
data_block_convert
er
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
,
graph
,
...
...
@@ -117,9 +125,10 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
batch_size
=
10
minibatch_sampler
=
gb
.
MinibatchSampler
(
item_set
,
batch_size
=
batch_size
)
data_block_converter
=
Mapper
(
minibatch_sampler
,
to_data_block
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_sampl
er
,
data_block_convert
er
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
,
graph
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment