Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
5da7d391
Unverified
Commit
5da7d391
authored
Aug 01, 2023
by
peizhou001
Committed by
GitHub
Aug 01, 2023
Browse files
[Graphbolt]Add negative sampler udf (#6053)
parent
8c213ef1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
320 additions
and
0 deletions
+320
-0
python/dgl/graphbolt/__init__.py
python/dgl/graphbolt/__init__.py
+2
-0
python/dgl/graphbolt/impl/__init__.py
python/dgl/graphbolt/impl/__init__.py
+1
-0
python/dgl/graphbolt/impl/uniform_negative_sampler.py
python/dgl/graphbolt/impl/uniform_negative_sampler.py
+93
-0
python/dgl/graphbolt/link_data_format.py
python/dgl/graphbolt/link_data_format.py
+23
-0
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+132
-0
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+69
-0
No files found.
python/dgl/graphbolt/__init__.py
View file @
5da7d391
...
@@ -15,6 +15,8 @@ from .impl import *
...
@@ -15,6 +15,8 @@ from .impl import *
from
.dataloader
import
*
from
.dataloader
import
*
from
.subgraph_sampler
import
*
from
.subgraph_sampler
import
*
from
.sampled_subgraph
import
*
from
.sampled_subgraph
import
*
from
.link_data_format
import
*
from
.negative_sampler
import
*
from
.utils
import
unique_and_compact_node_pairs
from
.utils
import
unique_and_compact_node_pairs
...
...
python/dgl/graphbolt/impl/__init__.py
View file @
5da7d391
...
@@ -4,3 +4,4 @@ from .ondisk_metadata import *
...
@@ -4,3 +4,4 @@ from .ondisk_metadata import *
from
.torch_based_feature_store
import
*
from
.torch_based_feature_store
import
*
from
.csc_sampling_graph
import
*
from
.csc_sampling_graph
import
*
from
.sampled_subgraph_impl
import
*
from
.sampled_subgraph_impl
import
*
from
.uniform_negative_sampler
import
*
python/dgl/graphbolt/impl/uniform_negative_sampler.py
0 → 100644
View file @
5da7d391
"""Uniform negative sampler for GraphBolt."""
from
..negative_sampler
import
NegativeSampler
class
UniformNegativeSampler
(
NegativeSampler
):
"""
Negative samplers randomly select negative destination nodes for each
source node based on a uniform distribution. It's important to note that
the term 'negative' refers to false negatives, indicating that the sampled
pairs are not ensured to be absent in the graph.
For each edge ``(u, v)``, it is supposed to generate `negative_ratio` pairs
of negative edges ``(u, v')``, where ``v'`` is chosen uniformly from all
the nodes in the graph.
"""
def
__init__
(
self
,
datapipe
,
negative_ratio
,
link_data_format
,
graph
,
):
"""
Initlization for a uniform negative sampler.
Parameters
----------
datapipe : DataPipe
The datapipe.
negative_ratio : int
The proportion of negative samples to positive samples.
link_data_format : LinkDataFormat
Determines the format of the output data:
- Conditioned format: Outputs data as quadruples
`[u, v, [negative heads], [negative tails]]`. Here, 'u' and 'v'
are the source and destination nodes of positive edges, while
'negative heads' and 'negative tails' refer to the source and
destination nodes of negative edges.
- Independent format: Outputs data as triples `[u, v, label]`.
In this case, 'u' and 'v' are the source and destination nodes
of an edge, and 'label' indicates whether the edge is negative
(0) or positive (1).
graph : CSCSamplingGraph
The graph on which to perform negative sampling.
Examples
--------
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> link_data_format = gb.LinkDataFormat.INDEPENDENT
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
>>> minibatch_sampler = gb.MinibatchSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_sampler, 2, link_data_format, graph)
>>> for data in neg_sampler:
... print(data)
...
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
(tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> link_data_format = gb.LinkDataFormat.CONDITIONED
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs)
>>> minibatch_sampler = gb.MinibatchSampler(
...item_set, batch_size=1,
...)
>>> neg_sampler = gb.UniformNegativeSampler(
...minibatch_sampler, 2, link_data_format, graph)
>>> for data in neg_sampler:
... print(data)
...
(tensor([0]), tensor([1]), tensor([[0, 0]]), tensor([[2, 1]]))
(tensor([1]), tensor([2]), tensor([[1, 1]]), tensor([[1, 2]]))
"""
super
().
__init__
(
datapipe
,
negative_ratio
,
link_data_format
)
self
.
graph
=
graph
def
_sample_with_etype
(
self
,
node_pairs
,
etype
=
None
):
return
self
.
graph
.
sample_negative_edges_uniform
(
etype
,
node_pairs
,
self
.
negative_ratio
,
)
python/dgl/graphbolt/link_data_format.py
0 → 100644
View file @
5da7d391
"""Linked data format."""
from
enum
import
Enum
__all__
=
[
"LinkDataFormat"
]
class
LinkDataFormat
(
Enum
):
"""
An Enum class representing the two data formats used in link prediction:
Attributes:
CONDITIONED: Represents the 'conditioned' format where data is
structured as quadruples `[u, v, [negative heads], [negative tails]]`
indicating the source and destination nodes of positive and negative edges.
INDEPENDENT: Represents the 'independent' format where data is structured
as triples `[u, v, label]` indicating the source and destination nodes of
an edge, with a label (0 or 1) denoting it as negative or positive.
"""
CONDITIONED
=
"conditioned"
INDEPENDENT
=
"independent"
python/dgl/graphbolt/negative_sampler.py
0 → 100644
View file @
5da7d391
"""Negative samplers."""
from
_collections_abc
import
Mapping
import
torch
from
torchdata.datapipes.iter
import
Mapper
from
.link_data_format
import
LinkDataFormat
class
NegativeSampler
(
Mapper
):
"""
A negative sampler used to generate negative samples and return
a mix of positive and negative samples.
"""
def
__init__
(
self
,
datapipe
,
negative_ratio
,
link_data_format
,
):
"""
Initlization for a negative sampler.
Parameters
----------
datapipe : DataPipe
The datapipe.
negative_ratio : int
The proportion of negative samples to positive samples.
link_data_format : LinkDataFormat
Determines the format of the output data:
- Conditioned format: Outputs data as quadruples
`[u, v, [negative heads], [negative tails]]`. Here, 'u' and 'v'
are the source and destination nodes of positive edges, while
'negative heads' and 'negative tails' refer to the source and
destination nodes of negative edges.
- Independent format: Outputs data as triples `[u, v, label]`.
In this case, 'u' and 'v' are the source and destination nodes
of an edge, and 'label' indicates whether the edge is negative
(0) or positive (1).
"""
super
().
__init__
(
datapipe
,
self
.
_sample
)
assert
negative_ratio
>
0
,
"Negative_ratio should be positive Integer."
self
.
negative_ratio
=
negative_ratio
self
.
link_data_format
=
link_data_format
def
_sample
(
self
,
node_pairs
):
"""
Generate a mix of positive and negative samples.
Parameters
----------
node_pairs : Tuple[Tensor] or Dict[etype, Tuple[Tensor]]
A tuple of tensors or a dictionary represents source-destination
node pairs of positive edges, where positive means the edge must
exist in the graph.
Returns
-------
Tuple[Tensor] or Dict[etype, Tuple[Tensor]]
A collection of edges or a dictionary that maps etypes to edges,
which includes both positive and negative samples.
"""
if
isinstance
(
node_pairs
,
Mapping
):
return
{
etype
:
self
.
_collate
(
pos_pairs
,
self
.
_sample_with_etype
(
pos_pairs
,
etype
)
)
for
etype
,
pos_pairs
in
node_pairs
.
items
()
}
else
:
return
self
.
_collate
(
node_pairs
,
self
.
_sample_with_etype
(
node_pairs
,
None
)
)
def
_sample_with_etype
(
self
,
node_pairs
,
etype
=
None
):
"""Generate negative pairs for a given etype form positive pairs
for a given etype.
Parameters
----------
node_pairs : Tuple[Tensor]
A tuple of tensors or a dictionary represents source-destination
node pairs of positive edges, where positive means the edge must
exist in the graph.
etype : (str, str, str)
Canonical edge type.
Returns
-------
Tuple[Tensor]
A collection of negative node pairs.
"""
def
_collate
(
self
,
pos_pairs
,
neg_pairs
):
"""Collates positive and negative samples.
Parameters
----------
pos_pairs : Tuple[Tensor]
A tuple of tensors represents source-destination node pairs of
positive edges, where positive means the edge must exist in
the graph.
neg_pairs : Tuple[Tensor]
A tuple of tensors represents source-destination node pairs of
negative edges, where negative means the edge may not exist in
the graph.
Returns
-------
Tuple[Tensor]
A mixed collection of positive and negative node pairs.
"""
if
self
.
link_data_format
==
LinkDataFormat
.
INDEPENDENT
:
pos_src
,
pos_dst
=
pos_pairs
neg_src
,
neg_dst
=
neg_pairs
pos_label
=
torch
.
ones_like
(
pos_src
)
neg_label
=
torch
.
zeros_like
(
neg_src
)
src
=
torch
.
cat
([
pos_src
,
neg_src
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
label
=
torch
.
cat
([
pos_label
,
neg_label
])
return
(
src
,
dst
,
label
)
elif
self
.
link_data_format
==
LinkDataFormat
.
CONDITIONED
:
pos_src
,
pos_dst
=
pos_pairs
neg_src
,
neg_dst
=
neg_pairs
neg_src
=
neg_src
.
view
(
-
1
,
self
.
negative_ratio
)
neg_dst
=
neg_dst
.
view
(
-
1
,
self
.
negative_ratio
)
return
(
pos_src
,
pos_dst
,
neg_src
,
neg_dst
)
else
:
raise
ValueError
(
"Unsupported link data format."
)
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
0 → 100644
View file @
5da7d391
import
dgl.graphbolt
as
gb
import
gb_test_utils
import
pytest
import
torch
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
def
test_NegativeSampler_Independent_Format
(
negative_ratio
):
# Construct CSCSamplingGraph.
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
num_seeds
),
torch
.
arange
(
num_seeds
,
num_seeds
*
2
),
)
)
batch_size
=
10
minibatch_sampler
=
gb
.
MinibatchSampler
(
item_set
,
batch_size
=
batch_size
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_sampler
,
negative_ratio
,
gb
.
LinkDataFormat
.
INDEPENDENT
,
graph
,
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
src
,
dst
,
label
=
data
# Assertation
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
label
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
torch
.
all
(
torch
.
eq
(
label
[:
batch_size
],
1
))
assert
torch
.
all
(
torch
.
eq
(
label
[
batch_size
:],
0
))
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
def
test_NegativeSampler_Conditioned_Format
(
negative_ratio
):
# Construct CSCSamplingGraph.
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
num_seeds
),
torch
.
arange
(
num_seeds
,
num_seeds
*
2
),
)
)
batch_size
=
10
minibatch_sampler
=
gb
.
MinibatchSampler
(
item_set
,
batch_size
=
batch_size
)
# Construct NegativeSampler.
negative_sampler
=
gb
.
UniformNegativeSampler
(
minibatch_sampler
,
negative_ratio
,
gb
.
LinkDataFormat
.
CONDITIONED
,
graph
,
)
# Perform Negative sampling.
for
data
in
negative_sampler
:
pos_src
,
pos_dst
,
neg_src
,
neg_dst
=
data
# Assertation
assert
len
(
pos_src
)
==
batch_size
assert
len
(
pos_dst
)
==
batch_size
assert
len
(
neg_src
)
==
batch_size
assert
len
(
neg_dst
)
==
batch_size
assert
neg_src
.
numel
()
==
batch_size
*
negative_ratio
assert
neg_dst
.
numel
()
==
batch_size
*
negative_ratio
expected_src
=
pos_src
.
repeat
(
negative_ratio
).
view
(
-
1
,
negative_ratio
)
assert
torch
.
equal
(
expected_src
,
neg_src
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment