Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
44f4b0e2
Unverified
Commit
44f4b0e2
authored
Aug 18, 2023
by
peizhou001
Committed by
GitHub
Aug 18, 2023
Browse files
[Graphbolt]Adapt negative sampler (#6165)
parent
ceb25724
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
57 additions
and
46 deletions
+57
-46
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+49
-42
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+8
-4
No files found.
python/dgl/graphbolt/negative_sampler.py
View file @
44f4b0e2
...
...
@@ -6,6 +6,7 @@ import torch
from
torchdata.datapipes.iter
import
Mapper
from
.data_format
import
LinkPredictionEdgeFormat
from
.link_prediction_block
import
LinkPredictionBlock
class
NegativeSampler
(
Mapper
):
...
...
@@ -50,21 +51,20 @@ class NegativeSampler(Mapper):
Returns
-------
Tuple[Tensor] or Dict[etype, Tuple[Tensor]]
A
collection of edges or a dictionary that maps etypes to edges,
which includes both positive and
negative samples.
LinkPredictionBlock
An instance of 'LinkPredictionBlock' encompasses both positive and
negative samples.
"""
data
=
LinkPredictionBlock
(
node_pair
=
node_pairs
)
if
isinstance
(
node_pairs
,
Mapping
):
return
{
etype
:
self
.
_collate
(
pos_pairs
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
)
for
etype
,
pos_pairs
in
node_pairs
.
items
():
self
.
_collate
(
data
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
)
,
etype
)
for
etype
,
pos_pairs
in
node_pairs
.
items
()
}
else
:
return
self
.
_collate
(
node_pairs
,
self
.
_sample_with_etype
(
node_pairs
,
None
)
)
self
.
_collate
(
data
,
self
.
_sample_with_etype
(
node_pairs
))
return
data
def
_sample_with_etype
(
self
,
node_pairs
,
etype
=
None
):
"""Generate negative pairs for a given etype from positive pairs
...
...
@@ -86,49 +86,56 @@ class NegativeSampler(Mapper):
"""
raise
NotImplementedError
def
_collate
(
self
,
pos_pairs
,
neg_pairs
):
"""Collates positive and negative samples.
def
_collate
(
self
,
data
,
neg_pairs
,
etype
=
None
):
"""Collates positive and negative samples into data.
Parameters
----------
pos_pairs : Tuple[Tensor]
A tuple of tensors represents source-destination node pairs of
positive edges, where positive means the edge must exist in
the graph.
data : LinkPredictionBlock
The input data, which contains positive node pairs, will be filled
with negative information in this function.
neg_pairs : Tuple[Tensor]
A tuple of tensors representing source-destination node pairs of
negative edges, where negative means the edge may not exist in
the graph.
Returns
-------
Tuple[Tensor]
A mixed collection of positive and negative node pairs.
etype : (str, str, str)
Canonical edge type.
"""
pos_src
,
pos_dst
=
data
.
node_pair
neg_src
,
neg_dst
=
neg_pairs
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
pos_src
,
pos_dst
=
pos_pairs
neg_src
,
neg_dst
=
neg_pairs
pos_label
=
torch
.
ones_like
(
pos_src
)
neg_label
=
torch
.
zeros_like
(
neg_src
)
src
=
torch
.
cat
([
pos_src
,
neg_src
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
label
=
torch
.
cat
([
pos_label
,
neg_label
])
return
(
src
,
dst
,
label
)
elif
self
.
output_format
==
LinkPredictionEdgeFormat
.
CONDITIONED
:
pos_src
,
pos_dst
=
pos_pairs
neg_src
,
neg_dst
=
neg_pairs
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
return
(
pos_src
,
pos_dst
,
neg_src
,
neg_dst
)
elif
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
:
pos_src
,
pos_dst
=
pos_pairs
neg_src
,
_
=
neg_pairs
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
return
(
pos_src
,
pos_dst
,
neg_src
)
elif
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
:
pos_src
,
pos_dst
=
pos_pairs
_
,
neg_dst
=
neg_pairs
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
return
(
pos_src
,
pos_dst
,
neg_dst
)
if
etype
:
data
.
node_pair
[
etype
]
=
(
src
,
dst
)
data
.
label
[
etype
]
=
label
else
:
data
.
node_pair
=
(
src
,
dst
)
data
.
label
=
label
else
:
raise
ValueError
(
"Unsupported output format."
)
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
CONDITIONED
:
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
elif
(
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
):
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
neg_dst
=
None
elif
(
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
):
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
neg_src
=
None
else
:
raise
TypeError
(
f
"Unsupported output format
{
self
.
output_format
}
."
)
if
etype
:
data
.
negative_head
[
etype
]
=
neg_src
data
.
negative_tail
[
etype
]
=
neg_dst
else
:
data
.
negative_head
=
neg_src
data
.
negative_tail
=
neg_dst
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
44f4b0e2
...
...
@@ -26,7 +26,8 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
src
,
dst
,
label
=
data
src
,
dst
=
data
.
node_pair
label
=
data
.
label
# Assertion
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
...
...
@@ -57,7 +58,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
pos_src
,
pos_dst
,
neg_src
,
neg_dst
=
data
pos_src
,
pos_dst
=
data
.
node_pair
neg_src
,
neg_dst
=
data
.
negative_head
,
data
.
negative_tail
# Assertion
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
...
@@ -91,7 +93,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
pos_src
,
pos_dst
,
neg_src
=
data
pos_src
,
pos_dst
=
data
.
node_pair
neg_src
=
data
.
negative_head
# Assertion
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
...
@@ -123,7 +126,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
pos_src
,
pos_dst
,
neg_dst
=
data
pos_src
,
pos_dst
=
data
.
node_pair
neg_dst
=
data
.
negative_tail
# Assertion
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment