Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6b99f328
Unverified
Commit
6b99f328
authored
Aug 30, 2023
by
Rhett Ying
Committed by
GitHub
Aug 30, 2023
Browse files
[GraphBolt] use str etype in graphs (#6235)
parent
911c4aba
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
98 additions
and
106 deletions
+98
-106
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+1
-4
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+8
-7
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
+20
-22
python/dgl/graphbolt/sampled_subgraph.py
python/dgl/graphbolt/sampled_subgraph.py
+10
-5
python/dgl/graphbolt/utils/sample_utils.py
python/dgl/graphbolt/utils/sample_utils.py
+14
-11
tests/python/pytorch/graphbolt/gb_test_utils.py
tests/python/pytorch/graphbolt/gb_test_utils.py
+1
-1
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+30
-42
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+1
-1
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
...thon/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
+5
-5
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+3
-3
tests/python/pytorch/graphbolt/test_graphbolt_utils.py
tests/python/pytorch/graphbolt/test_graphbolt_utils.py
+4
-4
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+1
-1
No files found.
python/dgl/graphbolt/feature_fetcher.py
View file @
6b99f328
...
...
@@ -66,10 +66,7 @@ class FeatureFetcher(Mapper):
edges
=
(
subgraph
.
reverse_edge_ids
if
not
type_name
# TODO(#6211): Clean up the edge type converter.
else
subgraph
.
reverse_edge_ids
.
get
(
tuple
(
type_name
.
split
(
":"
)),
None
)
else
subgraph
.
reverse_edge_ids
.
get
(
type_name
,
None
)
)
if
edges
is
not
None
:
data
.
edge_feature
[
i
][
...
...
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
6b99f328
...
...
@@ -4,14 +4,14 @@ import os
import
tarfile
import
tempfile
from
collections
import
defaultdict
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
Optional
,
Union
import
torch
from
...base
import
ETYPE
from
...convert
import
to_homogeneous
from
...heterograph
import
DGLGraph
from
..base
import
etype_str_to_tuple
from
..base
import
etype_str_to_tuple
,
etype_tuple_to_str
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
...
...
@@ -21,7 +21,7 @@ class GraphMetadata:
def
__init__
(
self
,
node_type_to_id
:
Dict
[
str
,
int
],
edge_type_to_id
:
Dict
[
Tuple
[
str
,
str
,
str
]
,
int
],
edge_type_to_id
:
Dict
[
str
,
int
],
):
"""Initialize the GraphMetadata object.
...
...
@@ -29,7 +29,7 @@ class GraphMetadata:
----------
node_type_to_id : Dict[str, int]
Dictionary from node types to node type IDs.
edge_type_to_id : Dict[
Tuple[str, str, str]
, int]
edge_type_to_id : Dict[
str
, int]
Dictionary from edge types to edge type IDs.
Raises
...
...
@@ -55,7 +55,7 @@ class GraphMetadata:
),
"Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for
edge_type
in
edge_types
:
src
,
edge
,
dst
=
edge_type
src
,
edge
,
dst
=
etype_str_to_tuple
(
edge_type
)
assert
isinstance
(
edge
,
str
),
"Edge type name should be string."
assert
(
src
in
node_types
...
...
@@ -238,7 +238,7 @@ class CSCSamplingGraph:
# converted to heterogeneous graphs.
node_pairs
=
defaultdict
(
list
)
for
etype
,
etype_id
in
self
.
metadata
.
edge_type_to_id
.
items
():
src_ntype
,
_
,
dst_ntype
=
etype
src_ntype
,
_
,
dst_ntype
=
etype
_str_to_tuple
(
etype
)
src_ntype_id
=
self
.
metadata
.
node_type_to_id
[
src_ntype
]
dst_ntype_id
=
self
.
metadata
.
node_type_to_id
[
dst_ntype
]
mask
=
type_per_edge
==
etype_id
...
...
@@ -719,7 +719,8 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
# Initialize metadata.
node_type_to_id
=
{
ntype
:
g
.
get_ntype_id
(
ntype
)
for
ntype
in
g
.
ntypes
}
edge_type_to_id
=
{
etype
:
g
.
get_etype_id
(
etype
)
for
etype
in
g
.
canonical_etypes
etype_tuple_to_str
(
etype
):
g
.
get_etype_id
(
etype
)
for
etype
in
g
.
canonical_etypes
}
metadata
=
GraphMetadata
(
node_type_to_id
,
edge_type_to_id
)
...
...
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
View file @
6b99f328
...
...
@@ -5,6 +5,7 @@ from typing import Dict, Tuple, Union
import
torch
from
..base
import
etype_str_to_tuple
from
..sampled_subgraph
import
SampledSubgraph
...
...
@@ -14,11 +15,11 @@ class SampledSubgraphImpl(SampledSubgraph):
Examples
--------
>>> node_pairs = {
('A', '
relation
', 'B'
): (torch.tensor([0, 1, 2]),
>>> node_pairs = {
"A:
relation
:B"
): (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {
('A', '
relation
', 'B')
: torch.tensor([19, 20, 21])}
>>> reverse_edge_ids = {
"A:
relation
:B"
: torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids,
...
...
@@ -26,33 +27,29 @@ class SampledSubgraphImpl(SampledSubgraph):
... reverse_edge_ids=reverse_edge_ids
... )
>>> print(subgraph.node_pairs)
{
('A', '
relation
', 'B')
: (tensor([0, 1, 2]), tensor([0, 1, 2]))}
{
"A:
relation
:B"
: (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.reverse_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(subgraph.reverse_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(subgraph.reverse_edge_ids)
{
('A', '
relation
', 'B')
: tensor([19, 20, 21])}
{
"A:
relation
:B"
: tensor([19, 20, 21])}
"""
node_pairs
:
Union
[
Dict
[
Tuple
[
str
,
str
,
str
]
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
]
=
None
reverse_column_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
reverse_row_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
reverse_edge_ids
:
Union
[
Dict
[
Tuple
[
str
,
str
,
str
],
torch
.
Tensor
],
torch
.
Tensor
]
=
None
reverse_edge_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
if
isinstance
(
self
.
node_pairs
,
dict
):
for
etype
,
pair
in
self
.
node_pairs
.
items
():
assert
(
isinstance
(
etype
,
tuple
)
and
len
(
etype
)
==
3
),
"Edge type should be a triplet of strings (str, str, str)."
assert
all
(
isinstance
(
item
,
str
)
for
item
in
etype
),
"Edge type should be a triplet of strings (str, str, str)."
isinstance
(
etype
,
str
)
and
len
(
etype_str_to_tuple
(
etype
))
==
3
),
"Edge type should be a string in format of str:str:str."
assert
(
isinstance
(
pair
,
tuple
)
and
len
(
pair
)
==
2
),
"Node pair should be a source-destination tuple (u, v)."
...
...
@@ -127,7 +124,7 @@ def _slice_subgraph(subgraph: SampledSubgraphImpl, index: torch.Tensor):
def
exclude_edges
(
subgraph
:
SampledSubgraphImpl
,
edges
:
Union
[
Dict
[
Tuple
[
str
,
str
,
str
]
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
],
)
->
SampledSubgraphImpl
:
...
...
@@ -142,7 +139,7 @@ def exclude_edges(
----------
subgraph : SampledSubgraphImpl
The sampled subgraph.
edges : Union[Dict[
Tuple[str, str, str]
, Tuple[torch.Tensor, torch.Tensor]],
edges : Union[Dict[
str
, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]]
Edges to exclude. If sampled subgraph is homogeneous, then `edges`
should be a pair of tensors representing the edges to exclude. If
...
...
@@ -156,11 +153,11 @@ def exclude_edges(
Examples
--------
>>> node_pairs = {
('A', '
relation
', 'B')
: (torch.tensor([0, 1, 2]),
>>> node_pairs = {
"A:
relation
:B"
: (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {
('A', '
relation
', 'B')
: torch.tensor([19, 20, 21])}
>>> reverse_edge_ids = {
"A:
relation
:B"
: torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids,
...
...
@@ -170,13 +167,13 @@ def exclude_edges(
>>> exclude_edges = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> result = gb.exclude_edges(subgraph, exclude_edges)
>>> print(result.node_pairs)
{
('A', '
relation
', 'B')
: (tensor([0]), tensor([0]))}
{
"A:
relation
:B"
: (tensor([0]), tensor([0]))}
>>> print(result.reverse_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(result.reverse_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(result.reverse_edge_ids)
{
('A', '
relation
', 'B')
: tensor([19])}
{
"A:
relation
:B"
: tensor([19])}
"""
assert
isinstance
(
subgraph
.
node_pairs
,
tuple
)
==
isinstance
(
edges
,
tuple
),
(
"The sampled subgraph and the edges to exclude should be both "
...
...
@@ -197,15 +194,16 @@ def exclude_edges(
else
:
index
=
{}
for
etype
,
pair
in
subgraph
.
node_pairs
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
reverse_row_node_ids
=
(
None
if
subgraph
.
reverse_row_node_ids
is
None
else
subgraph
.
reverse_row_node_ids
.
get
(
e
type
[
0
]
)
else
subgraph
.
reverse_row_node_ids
.
get
(
src_
type
)
)
reverse_column_node_ids
=
(
None
if
subgraph
.
reverse_column_node_ids
is
None
else
subgraph
.
reverse_column_node_ids
.
get
(
e
type
[
2
]
)
else
subgraph
.
reverse_column_node_ids
.
get
(
dst_
type
)
)
reverse_edges
=
_to_reverse_ids
(
pair
,
...
...
python/dgl/graphbolt/sampled_subgraph.py
View file @
6b99f328
"""Graphbolt sampled subgraph."""
# pylint: disable= invalid-name
from
typing
import
Dict
,
Tuple
from
typing
import
Dict
,
Tuple
,
Union
import
torch
...
...
@@ -14,7 +14,10 @@ class SampledSubgraph:
@
property
def
node_pairs
(
self
,
)
->
Tuple
[
torch
.
Tensor
]
or
Dict
[(
str
,
str
,
str
),
Tuple
[
torch
.
Tensor
]]:
)
->
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]:
"""Returns the node pairs representing source-destination edges.
- If `node_pairs` is a tuple: It should be in the format ('u', 'v')
representing source and destination pairs.
...
...
@@ -26,7 +29,7 @@ class SampledSubgraph:
@
property
def
reverse_column_node_ids
(
self
,
)
->
torch
.
Tensor
or
Dict
[
str
,
torch
.
Tensor
]:
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]
]
:
"""Returns corresponding reverse column node ids the original graph.
Column's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
...
...
@@ -42,7 +45,9 @@ class SampledSubgraph:
return
None
@
property
def
reverse_row_node_ids
(
self
)
->
torch
.
Tensor
or
Dict
[
str
,
torch
.
Tensor
]:
def
reverse_row_node_ids
(
self
,
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns corresponding reverse row node ids the original graph.
Row's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
...
...
@@ -57,7 +62,7 @@ class SampledSubgraph:
return
None
@
property
def
reverse_edge_ids
(
self
)
->
torch
.
Tensor
or
Dict
[
str
,
torch
.
Tensor
]:
def
reverse_edge_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]
]
:
"""Returns corresponding reverse edge ids the original graph.
Reverse edge ids in the original graph. This is useful when edge
features are needed.
...
...
python/dgl/graphbolt/utils/sample_utils.py
View file @
6b99f328
...
...
@@ -5,6 +5,8 @@ from typing import Dict, List, Tuple, Union
import
torch
from
..base
import
etype_str_to_tuple
def
unique_and_compact
(
nodes
:
Union
[
...
...
@@ -61,7 +63,7 @@ def unique_and_compact(
def
unique_and_compact_node_pairs
(
node_pairs
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
,
str
]
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
],
unique_dst_nodes
:
Union
[
torch
.
Tensor
,
...
...
@@ -73,8 +75,8 @@ def unique_and_compact_node_pairs(
Parameters
----------
node_pairs : Tuple[torch.Tensor, torch.Tensor]
or
\
Dict(
Tuple[str, str, str]
, Tuple[torch.Tensor, torch.Tensor])
node_pairs :
Union[
Tuple[torch.Tensor, torch.Tensor]
,
Dict(
str
, Tuple[torch.Tensor, torch.Tensor])
]
Node pairs representing source-destination edges.
- If `node_pairs` is a tuple: It means the graph is homogeneous.
Also, it should be in the format ('u', 'v') representing source
...
...
@@ -102,20 +104,20 @@ def unique_and_compact_node_pairs(
>>> import dgl.graphbolt as gb
>>> N1 = torch.LongTensor([1, 2, 2])
>>> N2 = torch.LongTensor([5, 6, 5])
>>> node_pairs = {
(
"n1
", "e1", "
n2"
)
: (N1, N2),
...
(
"n2
", "e2", "
n1"
)
: (N2, N1)}
>>> node_pairs = {"n1
:e1:
n2": (N1, N2),
... "n2
:e2:
n1": (N2, N1)}
>>> unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
... node_pairs
... )
>>> print(unique_nodes)
{'n1': tensor([1, 2]), 'n2': tensor([5, 6])}
>>> print(compacted_node_pairs)
{
('n1', 'e1', 'n2')
: (tensor([0, 1, 1]), tensor([0, 1, 0])),
('n2', 'e2', 'n1')
: (tensor([0, 1, 0]), tensor([0, 1, 1]))}
{
"n1:e1:n2"
: (tensor([0, 1, 1]), tensor([0, 1, 0])),
"n2:e2:n1"
: (tensor([0, 1, 0]), tensor([0, 1, 1]))}
"""
is_homogeneous
=
not
isinstance
(
node_pairs
,
dict
)
if
is_homogeneous
:
node_pairs
=
{
(
"_N
"
,
"_E"
,
"
_N"
)
:
node_pairs
}
node_pairs
=
{
"_N
:_E:
_N"
:
node_pairs
}
if
unique_dst_nodes
is
not
None
:
assert
isinstance
(
unique_dst_nodes
,
torch
.
Tensor
...
...
@@ -126,8 +128,9 @@ def unique_and_compact_node_pairs(
src_nodes
=
defaultdict
(
list
)
dst_nodes
=
defaultdict
(
list
)
for
etype
,
(
src_node
,
dst_node
)
in
node_pairs
.
items
():
src_nodes
[
etype
[
0
]].
append
(
src_node
)
dst_nodes
[
etype
[
2
]].
append
(
dst_node
)
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src_nodes
[
src_type
].
append
(
src_node
)
dst_nodes
[
dst_type
].
append
(
dst_node
)
src_nodes
=
{
ntype
:
torch
.
cat
(
nodes
)
for
ntype
,
nodes
in
src_nodes
.
items
()}
dst_nodes
=
{
ntype
:
torch
.
cat
(
nodes
)
for
ntype
,
nodes
in
dst_nodes
.
items
()}
# Compute unique destination nodes if not provided.
...
...
@@ -156,7 +159,7 @@ def unique_and_compact_node_pairs(
# Map back with the same order.
for
etype
,
pair
in
node_pairs
.
items
():
num_elem
=
pair
[
0
].
size
(
0
)
src_type
,
_
,
dst_type
=
etype
src_type
,
_
,
dst_type
=
etype
_str_to_tuple
(
etype
)
src
=
compacted_src
[
src_type
][:
num_elem
]
dst
=
compacted_dst
[
dst_type
][:
num_elem
]
compacted_node_pairs
[
etype
]
=
(
src
,
dst
)
...
...
tests/python/pytorch/graphbolt/gb_test_utils.py
View file @
6b99f328
...
...
@@ -43,7 +43,7 @@ def get_metadata(num_ntypes, num_etypes):
for
n2
in
range
(
n1
,
num_ntypes
):
if
count
>=
num_etypes
:
break
etypes
.
update
({
(
f
"n
{
n1
}
"
,
f
"
e
{
count
}
"
,
f
"
n
{
n2
}
"
)
:
count
})
etypes
.
update
({
f
"n
{
n1
}
:
e
{
count
}
:
n
{
n2
}
"
:
count
})
count
+=
1
return
gb
.
GraphMetadata
(
ntypes
,
etypes
)
...
...
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
6b99f328
...
...
@@ -73,7 +73,7 @@ def test_hetero_empty_graph(num_nodes):
)
def
test_metadata_with_ntype_exception
(
ntypes
):
with
pytest
.
raises
(
Exception
):
gb
.
GraphMetadata
(
ntypes
,
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
1
})
gb
.
GraphMetadata
(
ntypes
,
{
"n1
:e1:
n2"
:
1
})
@
unittest
.
skipIf
(
...
...
@@ -87,9 +87,9 @@ def test_metadata_with_ntype_exception(ntypes):
{
"e1"
:
1
},
{(
"n1"
,
"e1"
):
1
},
{(
"n1"
,
"e1"
,
10
):
1
},
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
1
,
(
"n1"
,
"e2"
,
"n3"
):
1
},
{
"n1
:e1:
n2"
:
1
,
(
"n1"
,
"e2"
,
"n3"
):
1
},
{(
"n1"
,
"e1"
,
"n10"
):
1
},
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
1.5
},
{
"n1
:e1:
n2"
:
1.5
},
],
)
def
test_metadata_with_etype_exception
(
etypes
):
...
...
@@ -320,10 +320,10 @@ def test_in_subgraph_heterogeneous():
"N1"
:
1
,
}
etypes
=
{
(
"N0
"
,
"R0"
,
"
N0"
)
:
0
,
(
"N0
"
,
"R1"
,
"
N1"
)
:
1
,
(
"N1
"
,
"R2"
,
"
N0"
)
:
2
,
(
"N1
"
,
"R3"
,
"
N1"
)
:
3
,
"N0
:R0:
N0"
:
0
,
"N0
:R1:
N1"
:
1
,
"N1
:R2:
N0"
:
2
,
"N1
:R3:
N1"
:
3
,
}
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
...
...
@@ -403,8 +403,8 @@ def test_sample_neighbors_homo():
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_sample_neighbors_hetero
(
labor
):
"""Original graph in COO:
(
"n1
", "e1", "
n2"
)
:[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
(
"n2
", "e2", "
n1"
)
:[0, 0, 1, 2], [0, 1, 1 ,0]
"n1
:e1:
n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2
:e2:
n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
...
...
@@ -413,7 +413,7 @@ def test_sample_neighbors_hetero(labor):
"""
# Initialize data.
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
num_nodes
=
5
num_edges
=
9
...
...
@@ -441,11 +441,11 @@ def test_sample_neighbors_hetero(labor):
# Verify in subgraph.
expected_node_pairs
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
(
"n1
:e1:
n2"
:
(
torch
.
LongTensor
([
0
,
1
]),
torch
.
LongTensor
([
0
,
0
]),
),
(
"n2
"
,
"e2"
,
"
n1"
)
:
(
"n2
:e2:
n1"
:
(
torch
.
LongTensor
([
0
,
2
]),
torch
.
LongTensor
([
0
,
0
]),
),
...
...
@@ -484,8 +484,8 @@ def test_sample_neighbors_fanouts(
fanouts
,
expected_sampled_num1
,
expected_sampled_num2
,
labor
):
"""Original graph in COO:
(
"n1
", "e1", "
n2"
)
:[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
(
"n2
", "e2", "
n1"
)
:[0, 0, 1, 2], [0, 1, 1 ,0]
"n1
:e1:
n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2
:e2:
n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
...
...
@@ -494,7 +494,7 @@ def test_sample_neighbors_fanouts(
"""
# Initialize data.
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
num_nodes
=
5
num_edges
=
9
...
...
@@ -520,14 +520,8 @@ def test_sample_neighbors_fanouts(
subgraph
=
sampler
(
nodes
,
fanouts
)
# Verify in subgraph.
assert
(
subgraph
.
node_pairs
[(
"n1"
,
"e1"
,
"n2"
)][
0
].
numel
()
==
expected_sampled_num1
)
assert
(
subgraph
.
node_pairs
[(
"n2"
,
"e2"
,
"n1"
)][
0
].
numel
()
==
expected_sampled_num2
)
assert
subgraph
.
node_pairs
[
"n1:e1:n2"
][
0
].
numel
()
==
expected_sampled_num1
assert
subgraph
.
node_pairs
[
"n2:e2:n1"
][
0
].
numel
()
==
expected_sampled_num2
@
unittest
.
skipIf
(
...
...
@@ -542,8 +536,8 @@ def test_sample_neighbors_replace(
replace
,
expected_sampled_num1
,
expected_sampled_num2
):
"""Original graph in COO:
(
"n1
", "e1", "
n2"
)
:[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
(
"n2
", "e2", "
n1"
)
:[0, 0, 1, 2], [0, 1, 1 ,0]
"n1
:e1:
n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2
:e2:
n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
...
...
@@ -552,7 +546,7 @@ def test_sample_neighbors_replace(
"""
# Initialize data.
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
num_nodes
=
5
num_edges
=
9
...
...
@@ -578,14 +572,8 @@ def test_sample_neighbors_replace(
)
# Verify in subgraph.
assert
(
subgraph
.
node_pairs
[(
"n1"
,
"e1"
,
"n2"
)][
0
].
numel
()
==
expected_sampled_num1
)
assert
(
subgraph
.
node_pairs
[(
"n2"
,
"e2"
,
"n1"
)][
0
].
numel
()
==
expected_sampled_num2
)
assert
subgraph
.
node_pairs
[
"n1:e1:n2"
][
0
].
numel
()
==
expected_sampled_num1
assert
subgraph
.
node_pairs
[
"n2:e2:n1"
][
0
].
numel
()
==
expected_sampled_num2
@
unittest
.
skipIf
(
...
...
@@ -811,7 +799,7 @@ def test_from_dglgraph_homogeneous():
assert
torch
.
equal
(
gb_g
.
node_type_offset
,
torch
.
tensor
([
0
,
1000
]))
assert
torch
.
all
(
gb_g
.
type_per_edge
==
0
)
assert
gb_g
.
metadata
.
node_type_to_id
==
{
"_N"
:
0
}
assert
gb_g
.
metadata
.
edge_type_to_id
==
{
(
"_N
"
,
"_E"
,
"
_N"
)
:
0
}
assert
gb_g
.
metadata
.
edge_type_to_id
==
{
"_N
:_E:
_N"
:
0
}
@
unittest
.
skipIf
(
...
...
@@ -855,10 +843,10 @@ def test_from_dglgraph_heterogeneous():
"n3"
:
2
,
}
assert
gb_g
.
metadata
.
edge_type_to_id
==
{
(
"n1
"
,
"r12"
,
"
n2"
)
:
0
,
(
"n1
"
,
"r13"
,
"
n3"
)
:
1
,
(
"n2
"
,
"r21"
,
"
n1"
)
:
2
,
(
"n2
"
,
"r23"
,
"
n3"
)
:
3
,
"n1
:r12:
n2"
:
0
,
"n1
:r13:
n3"
:
1
,
"n2
:r21:
n1"
:
2
,
"n2
:r23:
n3"
:
3
,
}
...
...
@@ -972,9 +960,9 @@ def test_sample_neighbors_hetero_pick_number(
num_edges
=
9
ntypes
=
{
"N0"
:
0
,
"N1"
:
1
,
"N2"
:
2
,
"N3"
:
3
}
etypes
=
{
(
"N0
"
,
"R0"
,
"
N1"
)
:
0
,
(
"N0
"
,
"R1"
,
"
N2"
)
:
1
,
(
"N0
"
,
"R2"
,
"
N3"
)
:
2
,
"N0
:R0:
N1"
:
0
,
"N0
:R1:
N2"
:
1
,
"N0
:R2:
N3"
:
2
,
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
])
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
6b99f328
...
...
@@ -151,7 +151,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
...
...
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
View file @
6b99f328
...
...
@@ -73,7 +73,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
@
pytest
.
mark
.
parametrize
(
"reverse_column"
,
[
True
,
False
])
def
test_exclude_edges_hetero
(
reverse_row
,
reverse_column
):
node_pairs
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
(
"A
:
relation
:
B"
:
(
torch
.
tensor
([
0
,
1
,
2
]),
torch
.
tensor
([
2
,
1
,
0
]),
)
...
...
@@ -94,7 +94,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
else
:
reverse_column_node_ids
=
None
dst_to_exclude
=
torch
.
tensor
([
0
,
2
])
reverse_edge_ids
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
torch
.
tensor
([
19
,
20
,
21
])}
reverse_edge_ids
=
{
"A
:
relation
:
B"
:
torch
.
tensor
([
19
,
20
,
21
])}
subgraph
=
SampledSubgraphImpl
(
node_pairs
=
node_pairs
,
reverse_column_node_ids
=
reverse_column_node_ids
,
...
...
@@ -103,14 +103,14 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
)
edges_to_exclude
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
(
"A
:
relation
:
B"
:
(
src_to_exclude
,
dst_to_exclude
,
)
}
result
=
exclude_edges
(
subgraph
,
edges_to_exclude
)
expected_node_pairs
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
(
"A
:
relation
:
B"
:
(
torch
.
tensor
([
1
]),
torch
.
tensor
([
1
]),
)
...
...
@@ -127,7 +127,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
}
else
:
expected_column_node_ids
=
None
expected_edge_ids
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
torch
.
tensor
([
20
])}
expected_edge_ids
=
{
"A
:
relation
:
B"
:
torch
.
tensor
([
20
])}
_assert_container_equal
(
result
.
node_pairs
,
expected_node_pairs
)
_assert_container_equal
(
...
...
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
6b99f328
...
...
@@ -71,7 +71,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
...
...
@@ -120,8 +120,8 @@ def test_FeatureFetcher_with_edges_hetero():
def
add_node_and_edge_ids
(
seeds
):
subgraphs
=
[]
reverse_edge_ids
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
torch
.
randint
(
0
,
50
,
(
10
,)),
(
"n2
"
,
"e2"
,
"
n1"
)
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n1
:e1:
n2"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n2
:e2:
n1"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
}
for
_
in
range
(
3
):
subgraphs
.
append
(
...
...
tests/python/pytorch/graphbolt/test_graphbolt_utils.py
View file @
6b99f328
...
...
@@ -62,15 +62,15 @@ def test_unique_and_compact_node_pairs_hetero():
"n3"
:
unique_N3
,
}
node_pairs
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
(
"n1
:e1:
n2"
:
(
N1
[:
20
],
N2
,
),
(
"n1
"
,
"e2"
,
"
n3"
)
:
(
"n1
:e2:
n3"
:
(
N1
[
20
:
30
],
N3
,
),
(
"n2
"
,
"e3"
,
"
n3"
)
:
(
"n2
:e3:
n3"
:
(
N2
[
10
:],
N3
,
),
...
...
@@ -84,7 +84,7 @@ def test_unique_and_compact_node_pairs_hetero():
assert
torch
.
equal
(
torch
.
sort
(
nodes
)[
0
],
expected_nodes
)
for
etype
,
pair
in
compacted_node_pairs
.
items
():
u
,
v
=
pair
u_type
,
_
,
v_type
=
etype
u_type
,
_
,
v_type
=
gb
.
etype
_str_to_tuple
(
etype
)
u
,
v
=
unique_nodes
[
u_type
][
u
],
unique_nodes
[
v_type
][
v
]
expected_u
,
expected_v
=
node_pairs
[
etype
]
assert
torch
.
equal
(
u
,
expected_u
)
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
6b99f328
...
...
@@ -84,7 +84,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment