Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
44f4b0e2
"...pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "85c030f67332a60b6dd8d4259c4d233767517c2c"
Unverified
Commit
44f4b0e2
authored
Aug 18, 2023
by
peizhou001
Committed by
GitHub
Aug 18, 2023
Browse files
[Graphbolt]Adapt negative sampler (#6165)
parent
ceb25724
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
57 additions
and
46 deletions
+57
-46
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+49
-42
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+8
-4
No files found.
python/dgl/graphbolt/negative_sampler.py
View file @
44f4b0e2
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
from
.data_format
import
LinkPredictionEdgeFormat
from
.data_format
import
LinkPredictionEdgeFormat
from
.link_prediction_block
import
LinkPredictionBlock
class
NegativeSampler
(
Mapper
):
class
NegativeSampler
(
Mapper
):
...
@@ -50,21 +51,20 @@ class NegativeSampler(Mapper):
...
@@ -50,21 +51,20 @@ class NegativeSampler(Mapper):
Returns
Returns
-------
-------
Tuple[Tensor] or Dict[etype, Tuple[Tensor]]
LinkPredictionBlock
A
collection of edges or a dictionary that maps etypes to edges,
A
n instance of 'LinkPredictionBlock' encompasses both positive and
which includes both positive and
negative samples.
negative samples.
"""
"""
data
=
LinkPredictionBlock
(
node_pair
=
node_pairs
)
if
isinstance
(
node_pairs
,
Mapping
):
if
isinstance
(
node_pairs
,
Mapping
):
return
{
for
etype
,
pos_pairs
in
node_pairs
.
items
():
etype
:
self
.
_collate
(
self
.
_collate
(
pos_pairs
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
)
data
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
)
,
etype
)
)
for
etype
,
pos_pairs
in
node_pairs
.
items
()
}
else
:
else
:
return
self
.
_collate
(
self
.
_collate
(
data
,
self
.
_sample_with_etype
(
node_pairs
))
node_pairs
,
self
.
_sample_with_etype
(
node_pairs
,
None
)
return
data
)
def
_sample_with_etype
(
self
,
node_pairs
,
etype
=
None
):
def
_sample_with_etype
(
self
,
node_pairs
,
etype
=
None
):
"""Generate negative pairs for a given etype form positive pairs
"""Generate negative pairs for a given etype form positive pairs
...
@@ -86,49 +86,56 @@ class NegativeSampler(Mapper):
...
@@ -86,49 +86,56 @@ class NegativeSampler(Mapper):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
_collate
(
self
,
pos_pairs
,
neg_pairs
):
def
_collate
(
self
,
data
,
neg_pairs
,
etype
=
None
):
"""Collates positive and negative samples.
"""Collates positive and negative samples
into data
.
Parameters
Parameters
----------
----------
pos_pairs : Tuple[Tensor]
data : LinkPredictionBlock
A tuple of tensors represents source-destination node pairs of
The input data, which contains positive node pairs, will be filled
positive edges, where positive means the edge must exist in
with negative information in this function.
the graph.
neg_pairs : Tuple[Tensor]
neg_pairs : Tuple[Tensor]
A tuple of tensors represents source-destination node pairs of
A tuple of tensors represents source-destination node pairs of
negative edges, where negative means the edge may not exist in
negative edges, where negative means the edge may not exist in
the graph.
the graph.
etype : (str, str, str)
Returns
Canonical edge type.
-------
Tuple[Tensor]
A mixed collection of positive and negative node pairs.
"""
"""
pos_src
,
pos_dst
=
data
.
node_pair
neg_src
,
neg_dst
=
neg_pairs
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
pos_src
,
pos_dst
=
pos_pairs
neg_src
,
neg_dst
=
neg_pairs
pos_label
=
torch
.
ones_like
(
pos_src
)
pos_label
=
torch
.
ones_like
(
pos_src
)
neg_label
=
torch
.
zeros_like
(
neg_src
)
neg_label
=
torch
.
zeros_like
(
neg_src
)
src
=
torch
.
cat
([
pos_src
,
neg_src
])
src
=
torch
.
cat
([
pos_src
,
neg_src
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
label
=
torch
.
cat
([
pos_label
,
neg_label
])
label
=
torch
.
cat
([
pos_label
,
neg_label
])
return
(
src
,
dst
,
label
)
if
etype
:
elif
self
.
output_format
==
LinkPredictionEdgeFormat
.
CONDITIONED
:
data
.
node_pair
[
etype
]
=
(
src
,
dst
)
pos_src
,
pos_dst
=
pos_pairs
data
.
label
[
etype
]
=
label
neg_src
,
neg_dst
=
neg_pairs
else
:
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
data
.
node_pair
=
(
src
,
dst
)
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
data
.
label
=
label
return
(
pos_src
,
pos_dst
,
neg_src
,
neg_dst
)
elif
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
:
pos_src
,
pos_dst
=
pos_pairs
neg_src
,
_
=
neg_pairs
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
return
(
pos_src
,
pos_dst
,
neg_src
)
elif
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
:
pos_src
,
pos_dst
=
pos_pairs
_
,
neg_dst
=
neg_pairs
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
return
(
pos_src
,
pos_dst
,
neg_dst
)
else
:
else
:
raise
ValueError
(
"Unsupported output format."
)
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
CONDITIONED
:
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
elif
(
self
.
output_format
==
LinkPredictionEdgeFormat
.
HEAD_CONDITIONED
):
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
neg_dst
=
None
elif
(
self
.
output_format
==
LinkPredictionEdgeFormat
.
TAIL_CONDITIONED
):
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
neg_src
=
None
else
:
raise
TypeError
(
f
"Unsupported output format
{
self
.
output_format
}
."
)
if
etype
:
data
.
negative_head
[
etype
]
=
neg_src
data
.
negative_tail
[
etype
]
=
neg_dst
else
:
data
.
negative_head
=
neg_src
data
.
negative_tail
=
neg_dst
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
44f4b0e2
...
@@ -26,7 +26,8 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
...
@@ -26,7 +26,8 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
)
# Perform Negative sampling.
# Perform Negative sampling.
for
data
in
negative_sampler
:
for
data
in
negative_sampler
:
src
,
dst
,
label
=
data
src
,
dst
=
data
.
node_pair
label
=
data
.
label
# Assertation
# Assertation
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
...
@@ -57,7 +58,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
...
@@ -57,7 +58,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
)
# Perform Negative sampling.
# Perform Negative sampling.
for
data
in
negative_sampler
:
for
data
in
negative_sampler
:
pos_src
,
pos_dst
,
neg_src
,
neg_dst
=
data
pos_src
,
pos_dst
=
data
.
node_pair
neg_src
,
neg_dst
=
data
.
negative_head
,
data
.
negative_tail
# Assertation
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
@@ -91,7 +93,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
...
@@ -91,7 +93,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
)
# Perform Negative sampling.
# Perform Negative sampling.
for
data
in
negative_sampler
:
for
data
in
negative_sampler
:
pos_src
,
pos_dst
,
neg_src
=
data
pos_src
,
pos_dst
=
data
.
node_pair
neg_src
=
data
.
negative_head
# Assertation
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
@@ -123,7 +126,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
...
@@ -123,7 +126,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
)
# Perform Negative sampling.
# Perform Negative sampling.
for
data
in
negative_sampler
:
for
data
in
negative_sampler
:
pos_src
,
pos_dst
,
neg_dst
=
data
pos_src
,
pos_dst
=
data
.
node_pair
neg_dst
=
data
.
negative_tail
# Assertation
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment