Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
911c4aba
"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "3cf42fd199259c061bb7749cb65757ba0d8a5b67"
Unverified
Commit
911c4aba
authored
Aug 30, 2023
by
Rhett Ying
Committed by
GitHub
Aug 30, 2023
Browse files
[GraphBolt] use str etype in sampler (#6232)
parent
0be8d8a7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
37 additions
and
28 deletions
+37
-28
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+6
-5
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+13
-11
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+12
-6
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+2
-2
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+4
-4
No files found.
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
911c4aba
...
@@ -11,6 +11,7 @@ import torch
...
@@ -11,6 +11,7 @@ import torch
from
...base
import
ETYPE
from
...base
import
ETYPE
from
...convert
import
to_homogeneous
from
...convert
import
to_homogeneous
from
...heterograph
import
DGLGraph
from
...heterograph
import
DGLGraph
from
..base
import
etype_str_to_tuple
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
...
@@ -505,11 +506,11 @@ class CSCSamplingGraph:
...
@@ -505,11 +506,11 @@ class CSCSamplingGraph:
Parameters
Parameters
----------
----------
edge_type:
Tuple[
str
]
edge_type: str
The type of edges in the provided node_pairs. Any negative edges
The type of edges in the provided node_pairs. Any negative edges
sampled will also have the same type. If set to None, it will be
sampled will also have the same type. If set to None, it will be
considered as a homogeneous graph.
considered as a homogeneous graph.
node_pairs : Tuple[Tensor]
node_pairs : Tuple[
Tensor,
Tensor]
A tuple of two 1D tensors that represent the source and destination
A tuple of two 1D tensors that represent the source and destination
of positive edges, with 'positive' indicating that these edges are
of positive edges, with 'positive' indicating that these edges are
present in the graph. It's important to note that within the
present in the graph. It's important to note that within the
...
@@ -520,7 +521,7 @@ class CSCSamplingGraph:
...
@@ -520,7 +521,7 @@ class CSCSamplingGraph:
Returns
Returns
-------
-------
Tuple[Tensor]
Tuple[
Tensor,
Tensor]
A tuple consisting of two 1D tensors represents the source and
A tuple consisting of two 1D tensors represents the source and
destination of negative edges. In the context of a heterogeneous
destination of negative edges. In the context of a heterogeneous
graph, both the input nodes and the selected nodes are represented
graph, both the input nodes and the selected nodes are represented
...
@@ -528,12 +529,12 @@ class CSCSamplingGraph:
...
@@ -528,12 +529,12 @@ class CSCSamplingGraph:
`edge_type`. Note that negative refers to false negatives, which
`edge_type`. Note that negative refers to false negatives, which
means the edge could be present or not present in the graph.
means the edge could be present or not present in the graph.
"""
"""
if
edge_type
:
if
edge_type
is
not
None
:
assert
(
assert
(
self
.
node_type_offset
is
not
None
self
.
node_type_offset
is
not
None
),
"The 'node_type_offset' array is necessary for performing
\
),
"The 'node_type_offset' array is necessary for performing
\
negative sampling by edge type."
negative sampling by edge type."
_
,
_
,
dst_node_type
=
edge_type
_
,
_
,
dst_node_type
=
etype_str_to_tuple
(
edge_type
)
dst_node_type_id
=
self
.
metadata
.
node_type_to_id
[
dst_node_type
]
dst_node_type_id
=
self
.
metadata
.
node_type_to_id
[
dst_node_type
]
max_node_id
=
(
max_node_id
=
(
self
.
node_type_offset
[
dst_node_type_id
+
1
]
self
.
node_type_offset
[
dst_node_type_id
+
1
]
...
...
python/dgl/graphbolt/negative_sampler.py
View file @
911c4aba
...
@@ -80,16 +80,16 @@ class NegativeSampler(Mapper):
...
@@ -80,16 +80,16 @@ class NegativeSampler(Mapper):
Parameters
Parameters
----------
----------
node_pairs : Tuple[Tensor]
node_pairs : Tuple[
Tensor,
Tensor]
A tuple of tensors
or a dictionary
represent
s
source-destination
A tuple of tensors
that
represent source-destination
node pairs of
node pairs of
positive edges, where positive means the edge must
positive edges, where positive means the edge must
exist in the
exist in the
graph.
graph.
etype :
(
str
, str, str)
etype : str
Canonical edge type.
Canonical edge type.
Returns
Returns
-------
-------
Tuple[Tensor]
Tuple[
Tensor,
Tensor]
A collection of negative node pairs.
A collection of negative node pairs.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -102,14 +102,16 @@ class NegativeSampler(Mapper):
...
@@ -102,14 +102,16 @@ class NegativeSampler(Mapper):
data : LinkPredictionBlock
data : LinkPredictionBlock
The input data, which contains positive node pairs, will be filled
The input data, which contains positive node pairs, will be filled
with negative information in this function.
with negative information in this function.
neg_pairs : Tuple[Tensor]
neg_pairs : Tuple[
Tensor,
Tensor]
A tuple of tensors represents source-destination node pairs of
A tuple of tensors represents source-destination node pairs of
negative edges, where negative means the edge may not exist in
negative edges, where negative means the edge may not exist in
the graph.
the graph.
etype :
(
str
, str, str)
etype : str
Canonical edge type.
Canonical edge type.
"""
"""
pos_src
,
pos_dst
=
data
.
node_pair
[
etype
]
if
etype
else
data
.
node_pair
pos_src
,
pos_dst
=
(
data
.
node_pair
[
etype
]
if
etype
is
not
None
else
data
.
node_pair
)
neg_src
,
neg_dst
=
neg_pairs
neg_src
,
neg_dst
=
neg_pairs
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
if
self
.
output_format
==
LinkPredictionEdgeFormat
.
INDEPENDENT
:
pos_label
=
torch
.
ones_like
(
pos_src
)
pos_label
=
torch
.
ones_like
(
pos_src
)
...
@@ -117,7 +119,7 @@ class NegativeSampler(Mapper):
...
@@ -117,7 +119,7 @@ class NegativeSampler(Mapper):
src
=
torch
.
cat
([
pos_src
,
neg_src
])
src
=
torch
.
cat
([
pos_src
,
neg_src
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
dst
=
torch
.
cat
([
pos_dst
,
neg_dst
])
label
=
torch
.
cat
([
pos_label
,
neg_label
])
label
=
torch
.
cat
([
pos_label
,
neg_label
])
if
etype
:
if
etype
is
not
None
:
data
.
node_pair
[
etype
]
=
(
src
,
dst
)
data
.
node_pair
[
etype
]
=
(
src
,
dst
)
data
.
label
[
etype
]
=
label
data
.
label
[
etype
]
=
label
else
:
else
:
...
@@ -141,7 +143,7 @@ class NegativeSampler(Mapper):
...
@@ -141,7 +143,7 @@ class NegativeSampler(Mapper):
raise
TypeError
(
raise
TypeError
(
f
"Unsupported output format
{
self
.
output_format
}
."
f
"Unsupported output format
{
self
.
output_format
}
."
)
)
if
etype
:
if
etype
is
not
None
:
data
.
negative_head
[
etype
]
=
neg_src
data
.
negative_head
[
etype
]
=
neg_src
data
.
negative_tail
[
etype
]
=
neg_dst
data
.
negative_tail
[
etype
]
=
neg_dst
else
:
else
:
...
...
python/dgl/graphbolt/subgraph_sampler.py
View file @
911c4aba
...
@@ -5,6 +5,7 @@ from typing import Dict
...
@@ -5,6 +5,7 @@ from typing import Dict
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
from
.base
import
etype_str_to_tuple
from
.data_block
import
LinkPredictionBlock
,
NodeClassificationBlock
from
.data_block
import
LinkPredictionBlock
,
NodeClassificationBlock
from
.utils
import
unique_and_compact
from
.utils
import
unique_and_compact
...
@@ -51,14 +52,17 @@ class SubgraphSampler(Mapper):
...
@@ -51,14 +52,17 @@ class SubgraphSampler(Mapper):
if
is_heterogeneous
:
if
is_heterogeneous
:
# Collect nodes from all types of input.
# Collect nodes from all types of input.
nodes
=
defaultdict
(
list
)
nodes
=
defaultdict
(
list
)
for
(
src_type
,
_
,
dst_type
),
(
src
,
dst
)
in
node_pair
.
items
():
for
etype
,
(
src
,
dst
)
in
node_pair
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
nodes
[
src_type
].
append
(
src
)
nodes
[
src_type
].
append
(
src
)
nodes
[
dst_type
].
append
(
dst
)
nodes
[
dst_type
].
append
(
dst
)
if
has_neg_src
:
if
has_neg_src
:
for
(
src_type
,
_
,
_
),
src
in
neg_src
.
items
():
for
etype
,
src
in
neg_src
.
items
():
src_type
,
_
,
_
=
etype_str_to_tuple
(
etype
)
nodes
[
src_type
].
append
(
src
.
view
(
-
1
))
nodes
[
src_type
].
append
(
src
.
view
(
-
1
))
if
has_neg_dst
:
if
has_neg_dst
:
for
(
_
,
_
,
dst_type
),
dst
in
neg_dst
.
items
():
for
etype
,
dst
in
neg_dst
.
items
():
_
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
nodes
[
dst_type
].
append
(
dst
.
view
(
-
1
))
nodes
[
dst_type
].
append
(
dst
.
view
(
-
1
))
# Unique and compact the collected nodes.
# Unique and compact the collected nodes.
seeds
,
compacted
=
unique_and_compact
(
nodes
)
seeds
,
compacted
=
unique_and_compact
(
nodes
)
...
@@ -69,16 +73,18 @@ class SubgraphSampler(Mapper):
...
@@ -69,16 +73,18 @@ class SubgraphSampler(Mapper):
)
=
({},
{},
{})
)
=
({},
{},
{})
# Map back in same order as collect.
# Map back in same order as collect.
for
etype
,
_
in
node_pair
.
items
():
for
etype
,
_
in
node_pair
.
items
():
src_type
,
_
,
dst_type
=
etype
src_type
,
_
,
dst_type
=
etype
_str_to_tuple
(
etype
)
src
=
compacted
[
src_type
].
pop
(
0
)
src
=
compacted
[
src_type
].
pop
(
0
)
dst
=
compacted
[
dst_type
].
pop
(
0
)
dst
=
compacted
[
dst_type
].
pop
(
0
)
compacted_node_pair
[
etype
]
=
(
src
,
dst
)
compacted_node_pair
[
etype
]
=
(
src
,
dst
)
if
has_neg_src
:
if
has_neg_src
:
for
etype
,
_
in
neg_src
.
items
():
for
etype
,
_
in
neg_src
.
items
():
compacted_negative_head
[
etype
]
=
compacted
[
etype
[
0
]].
pop
(
0
)
src_type
,
_
,
_
=
etype_str_to_tuple
(
etype
)
compacted_negative_head
[
etype
]
=
compacted
[
src_type
].
pop
(
0
)
if
has_neg_dst
:
if
has_neg_dst
:
for
etype
,
_
in
neg_dst
.
items
():
for
etype
,
_
in
neg_dst
.
items
():
compacted_negative_tail
[
etype
]
=
compacted
[
etype
[
2
]].
pop
(
0
)
_
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
compacted_negative_tail
[
etype
]
=
compacted
[
dst_type
].
pop
(
0
)
else
:
else
:
# Collect nodes from all types of input.
# Collect nodes from all types of input.
nodes
=
list
(
node_pair
)
nodes
=
list
(
node_pair
)
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
911c4aba
...
@@ -184,13 +184,13 @@ def test_NegativeSampler_Hetero_Data(format):
...
@@ -184,13 +184,13 @@ def test_NegativeSampler_Hetero_Data(format):
graph
=
get_hetero_graph
()
graph
=
get_hetero_graph
()
itemset
=
gb
.
ItemSetDict
(
itemset
=
gb
.
ItemSetDict
(
{
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
gb
.
ItemSet
(
"n1
:e1:
n2"
:
gb
.
ItemSet
(
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
)
)
),
),
(
"n2
"
,
"e2"
,
"
n1"
)
:
gb
.
ItemSet
(
"n2
:e2:
n1"
:
gb
.
ItemSet
(
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
911c4aba
...
@@ -104,13 +104,13 @@ def test_SubgraphSampler_Link_Hetero(labor):
...
@@ -104,13 +104,13 @@ def test_SubgraphSampler_Link_Hetero(labor):
graph
=
get_hetero_graph
()
graph
=
get_hetero_graph
()
itemset
=
gb
.
ItemSetDict
(
itemset
=
gb
.
ItemSetDict
(
{
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
gb
.
ItemSet
(
"n1
:e1:
n2"
:
gb
.
ItemSet
(
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
)
)
),
),
(
"n2
"
,
"e2"
,
"
n1"
)
:
gb
.
ItemSet
(
"n2
:e2:
n1"
:
gb
.
ItemSet
(
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
...
@@ -142,13 +142,13 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
...
@@ -142,13 +142,13 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
graph
=
get_hetero_graph
()
graph
=
get_hetero_graph
()
itemset
=
gb
.
ItemSetDict
(
itemset
=
gb
.
ItemSetDict
(
{
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
gb
.
ItemSet
(
"n1
:e1:
n2"
:
gb
.
ItemSet
(
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
0
,
1
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
torch
.
LongTensor
([
0
,
2
,
0
,
1
]),
)
)
),
),
(
"n2
"
,
"e2"
,
"
n1"
)
:
gb
.
ItemSet
(
"n2
:e2:
n1"
:
gb
.
ItemSet
(
(
(
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
torch
.
LongTensor
([
0
,
1
,
1
,
0
,
0
,
1
]),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment