Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6b99f328
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e4b056fe652536ac89ff2c98e36b2d3685cbccd2"
Unverified
Commit
6b99f328
authored
Aug 30, 2023
by
Rhett Ying
Committed by
GitHub
Aug 30, 2023
Browse files
[GraphBolt] use str etype in graphs (#6235)
parent
911c4aba
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
98 additions
and
106 deletions
+98
-106
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+1
-4
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+8
-7
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
+20
-22
python/dgl/graphbolt/sampled_subgraph.py
python/dgl/graphbolt/sampled_subgraph.py
+10
-5
python/dgl/graphbolt/utils/sample_utils.py
python/dgl/graphbolt/utils/sample_utils.py
+14
-11
tests/python/pytorch/graphbolt/gb_test_utils.py
tests/python/pytorch/graphbolt/gb_test_utils.py
+1
-1
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+30
-42
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+1
-1
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
...thon/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
+5
-5
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+3
-3
tests/python/pytorch/graphbolt/test_graphbolt_utils.py
tests/python/pytorch/graphbolt/test_graphbolt_utils.py
+4
-4
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+1
-1
No files found.
python/dgl/graphbolt/feature_fetcher.py
View file @
6b99f328
...
@@ -66,10 +66,7 @@ class FeatureFetcher(Mapper):
...
@@ -66,10 +66,7 @@ class FeatureFetcher(Mapper):
edges
=
(
edges
=
(
subgraph
.
reverse_edge_ids
subgraph
.
reverse_edge_ids
if
not
type_name
if
not
type_name
# TODO(#6211): Clean up the edge type converter.
else
subgraph
.
reverse_edge_ids
.
get
(
type_name
,
None
)
else
subgraph
.
reverse_edge_ids
.
get
(
tuple
(
type_name
.
split
(
":"
)),
None
)
)
)
if
edges
is
not
None
:
if
edges
is
not
None
:
data
.
edge_feature
[
i
][
data
.
edge_feature
[
i
][
...
...
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
6b99f328
...
@@ -4,14 +4,14 @@ import os
...
@@ -4,14 +4,14 @@ import os
import
tarfile
import
tarfile
import
tempfile
import
tempfile
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
Optional
,
Union
import
torch
import
torch
from
...base
import
ETYPE
from
...base
import
ETYPE
from
...convert
import
to_homogeneous
from
...convert
import
to_homogeneous
from
...heterograph
import
DGLGraph
from
...heterograph
import
DGLGraph
from
..base
import
etype_str_to_tuple
from
..base
import
etype_str_to_tuple
,
etype_tuple_to_str
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
...
@@ -21,7 +21,7 @@ class GraphMetadata:
...
@@ -21,7 +21,7 @@ class GraphMetadata:
def
__init__
(
def
__init__
(
self
,
self
,
node_type_to_id
:
Dict
[
str
,
int
],
node_type_to_id
:
Dict
[
str
,
int
],
edge_type_to_id
:
Dict
[
Tuple
[
str
,
str
,
str
]
,
int
],
edge_type_to_id
:
Dict
[
str
,
int
],
):
):
"""Initialize the GraphMetadata object.
"""Initialize the GraphMetadata object.
...
@@ -29,7 +29,7 @@ class GraphMetadata:
...
@@ -29,7 +29,7 @@ class GraphMetadata:
----------
----------
node_type_to_id : Dict[str, int]
node_type_to_id : Dict[str, int]
Dictionary from node types to node type IDs.
Dictionary from node types to node type IDs.
edge_type_to_id : Dict[
Tuple[str, str, str]
, int]
edge_type_to_id : Dict[
str
, int]
Dictionary from edge types to edge type IDs.
Dictionary from edge types to edge type IDs.
Raises
Raises
...
@@ -55,7 +55,7 @@ class GraphMetadata:
...
@@ -55,7 +55,7 @@ class GraphMetadata:
),
"Multiple node types shoud not be mapped to a same id."
),
"Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
# Validate edge_type_to_id.
for
edge_type
in
edge_types
:
for
edge_type
in
edge_types
:
src
,
edge
,
dst
=
edge_type
src
,
edge
,
dst
=
etype_str_to_tuple
(
edge_type
)
assert
isinstance
(
edge
,
str
),
"Edge type name should be string."
assert
isinstance
(
edge
,
str
),
"Edge type name should be string."
assert
(
assert
(
src
in
node_types
src
in
node_types
...
@@ -238,7 +238,7 @@ class CSCSamplingGraph:
...
@@ -238,7 +238,7 @@ class CSCSamplingGraph:
# converted to heterogeneous graphs.
# converted to heterogeneous graphs.
node_pairs
=
defaultdict
(
list
)
node_pairs
=
defaultdict
(
list
)
for
etype
,
etype_id
in
self
.
metadata
.
edge_type_to_id
.
items
():
for
etype
,
etype_id
in
self
.
metadata
.
edge_type_to_id
.
items
():
src_ntype
,
_
,
dst_ntype
=
etype
src_ntype
,
_
,
dst_ntype
=
etype
_str_to_tuple
(
etype
)
src_ntype_id
=
self
.
metadata
.
node_type_to_id
[
src_ntype
]
src_ntype_id
=
self
.
metadata
.
node_type_to_id
[
src_ntype
]
dst_ntype_id
=
self
.
metadata
.
node_type_to_id
[
dst_ntype
]
dst_ntype_id
=
self
.
metadata
.
node_type_to_id
[
dst_ntype
]
mask
=
type_per_edge
==
etype_id
mask
=
type_per_edge
==
etype_id
...
@@ -719,7 +719,8 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
...
@@ -719,7 +719,8 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
# Initialize metadata.
# Initialize metadata.
node_type_to_id
=
{
ntype
:
g
.
get_ntype_id
(
ntype
)
for
ntype
in
g
.
ntypes
}
node_type_to_id
=
{
ntype
:
g
.
get_ntype_id
(
ntype
)
for
ntype
in
g
.
ntypes
}
edge_type_to_id
=
{
edge_type_to_id
=
{
etype
:
g
.
get_etype_id
(
etype
)
for
etype
in
g
.
canonical_etypes
etype_tuple_to_str
(
etype
):
g
.
get_etype_id
(
etype
)
for
etype
in
g
.
canonical_etypes
}
}
metadata
=
GraphMetadata
(
node_type_to_id
,
edge_type_to_id
)
metadata
=
GraphMetadata
(
node_type_to_id
,
edge_type_to_id
)
...
...
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
View file @
6b99f328
...
@@ -5,6 +5,7 @@ from typing import Dict, Tuple, Union
...
@@ -5,6 +5,7 @@ from typing import Dict, Tuple, Union
import
torch
import
torch
from
..base
import
etype_str_to_tuple
from
..sampled_subgraph
import
SampledSubgraph
from
..sampled_subgraph
import
SampledSubgraph
...
@@ -14,11 +15,11 @@ class SampledSubgraphImpl(SampledSubgraph):
...
@@ -14,11 +15,11 @@ class SampledSubgraphImpl(SampledSubgraph):
Examples
Examples
--------
--------
>>> node_pairs = {
('A', '
relation
', 'B'
): (torch.tensor([0, 1, 2]),
>>> node_pairs = {
"A:
relation
:B"
): (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {
('A', '
relation
', 'B')
: torch.tensor([19, 20, 21])}
>>> reverse_edge_ids = {
"A:
relation
:B"
: torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids,
... reverse_column_node_ids=reverse_column_node_ids,
...
@@ -26,33 +27,29 @@ class SampledSubgraphImpl(SampledSubgraph):
...
@@ -26,33 +27,29 @@ class SampledSubgraphImpl(SampledSubgraph):
... reverse_edge_ids=reverse_edge_ids
... reverse_edge_ids=reverse_edge_ids
... )
... )
>>> print(subgraph.node_pairs)
>>> print(subgraph.node_pairs)
{
('A', '
relation
', 'B')
: (tensor([0, 1, 2]), tensor([0, 1, 2]))}
{
"A:
relation
:B"
: (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.reverse_column_node_ids)
>>> print(subgraph.reverse_column_node_ids)
{'B': tensor([10, 11, 12])}
{'B': tensor([10, 11, 12])}
>>> print(subgraph.reverse_row_node_ids)
>>> print(subgraph.reverse_row_node_ids)
{'A': tensor([13, 14, 15])}
{'A': tensor([13, 14, 15])}
>>> print(subgraph.reverse_edge_ids)
>>> print(subgraph.reverse_edge_ids)
{
('A', '
relation
', 'B')
: tensor([19, 20, 21])}
{
"A:
relation
:B"
: tensor([19, 20, 21])}
"""
"""
node_pairs
:
Union
[
node_pairs
:
Union
[
Dict
[
Tuple
[
str
,
str
,
str
]
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
]
=
None
]
=
None
reverse_column_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
reverse_column_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
reverse_row_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
reverse_row_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
reverse_edge_ids
:
Union
[
reverse_edge_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
Dict
[
Tuple
[
str
,
str
,
str
],
torch
.
Tensor
],
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
isinstance
(
self
.
node_pairs
,
dict
):
if
isinstance
(
self
.
node_pairs
,
dict
):
for
etype
,
pair
in
self
.
node_pairs
.
items
():
for
etype
,
pair
in
self
.
node_pairs
.
items
():
assert
(
assert
(
isinstance
(
etype
,
tuple
)
and
len
(
etype
)
==
3
isinstance
(
etype
,
str
)
),
"Edge type should be a triplet of strings (str, str, str)."
and
len
(
etype_str_to_tuple
(
etype
))
==
3
assert
all
(
),
"Edge type should be a string in format of str:str:str."
isinstance
(
item
,
str
)
for
item
in
etype
),
"Edge type should be a triplet of strings (str, str, str)."
assert
(
assert
(
isinstance
(
pair
,
tuple
)
and
len
(
pair
)
==
2
isinstance
(
pair
,
tuple
)
and
len
(
pair
)
==
2
),
"Node pair should be a source-destination tuple (u, v)."
),
"Node pair should be a source-destination tuple (u, v)."
...
@@ -127,7 +124,7 @@ def _slice_subgraph(subgraph: SampledSubgraphImpl, index: torch.Tensor):
...
@@ -127,7 +124,7 @@ def _slice_subgraph(subgraph: SampledSubgraphImpl, index: torch.Tensor):
def
exclude_edges
(
def
exclude_edges
(
subgraph
:
SampledSubgraphImpl
,
subgraph
:
SampledSubgraphImpl
,
edges
:
Union
[
edges
:
Union
[
Dict
[
Tuple
[
str
,
str
,
str
]
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
],
],
)
->
SampledSubgraphImpl
:
)
->
SampledSubgraphImpl
:
...
@@ -142,8 +139,8 @@ def exclude_edges(
...
@@ -142,8 +139,8 @@ def exclude_edges(
----------
----------
subgraph : SampledSubgraphImpl
subgraph : SampledSubgraphImpl
The sampled subgraph.
The sampled subgraph.
edges : Union[Dict[
Tuple[str, str, str]
, Tuple[torch.Tensor, torch.Tensor]],
edges : Union[Dict[
str
, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]]
Tuple[torch.Tensor, torch.Tensor]]
Edges to exclude. If sampled subgraph is homogeneous, then `edges`
Edges to exclude. If sampled subgraph is homogeneous, then `edges`
should be a pair of tensors representing the edges to exclude. If
should be a pair of tensors representing the edges to exclude. If
sampled subgraph is heterogeneous, then `edges` should be a dictionary
sampled subgraph is heterogeneous, then `edges` should be a dictionary
...
@@ -156,11 +153,11 @@ def exclude_edges(
...
@@ -156,11 +153,11 @@ def exclude_edges(
Examples
Examples
--------
--------
>>> node_pairs = {
('A', '
relation
', 'B')
: (torch.tensor([0, 1, 2]),
>>> node_pairs = {
"A:
relation
:B"
: (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {
('A', '
relation
', 'B')
: torch.tensor([19, 20, 21])}
>>> reverse_edge_ids = {
"A:
relation
:B"
: torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids,
... reverse_column_node_ids=reverse_column_node_ids,
...
@@ -170,13 +167,13 @@ def exclude_edges(
...
@@ -170,13 +167,13 @@ def exclude_edges(
>>> exclude_edges = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> exclude_edges = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> result = gb.exclude_edges(subgraph, exclude_edges)
>>> result = gb.exclude_edges(subgraph, exclude_edges)
>>> print(result.node_pairs)
>>> print(result.node_pairs)
{
('A', '
relation
', 'B')
: (tensor([0]), tensor([0]))}
{
"A:
relation
:B"
: (tensor([0]), tensor([0]))}
>>> print(result.reverse_column_node_ids)
>>> print(result.reverse_column_node_ids)
{'B': tensor([10, 11, 12])}
{'B': tensor([10, 11, 12])}
>>> print(result.reverse_row_node_ids)
>>> print(result.reverse_row_node_ids)
{'A': tensor([13, 14, 15])}
{'A': tensor([13, 14, 15])}
>>> print(result.reverse_edge_ids)
>>> print(result.reverse_edge_ids)
{
('A', '
relation
', 'B')
: tensor([19])}
{
"A:
relation
:B"
: tensor([19])}
"""
"""
assert
isinstance
(
subgraph
.
node_pairs
,
tuple
)
==
isinstance
(
edges
,
tuple
),
(
assert
isinstance
(
subgraph
.
node_pairs
,
tuple
)
==
isinstance
(
edges
,
tuple
),
(
"The sampled subgraph and the edges to exclude should be both "
"The sampled subgraph and the edges to exclude should be both "
...
@@ -197,15 +194,16 @@ def exclude_edges(
...
@@ -197,15 +194,16 @@ def exclude_edges(
else
:
else
:
index
=
{}
index
=
{}
for
etype
,
pair
in
subgraph
.
node_pairs
.
items
():
for
etype
,
pair
in
subgraph
.
node_pairs
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
reverse_row_node_ids
=
(
reverse_row_node_ids
=
(
None
None
if
subgraph
.
reverse_row_node_ids
is
None
if
subgraph
.
reverse_row_node_ids
is
None
else
subgraph
.
reverse_row_node_ids
.
get
(
e
type
[
0
]
)
else
subgraph
.
reverse_row_node_ids
.
get
(
src_
type
)
)
)
reverse_column_node_ids
=
(
reverse_column_node_ids
=
(
None
None
if
subgraph
.
reverse_column_node_ids
is
None
if
subgraph
.
reverse_column_node_ids
is
None
else
subgraph
.
reverse_column_node_ids
.
get
(
e
type
[
2
]
)
else
subgraph
.
reverse_column_node_ids
.
get
(
dst_
type
)
)
)
reverse_edges
=
_to_reverse_ids
(
reverse_edges
=
_to_reverse_ids
(
pair
,
pair
,
...
...
python/dgl/graphbolt/sampled_subgraph.py
View file @
6b99f328
"""Graphbolt sampled subgraph."""
"""Graphbolt sampled subgraph."""
# pylint: disable= invalid-name
# pylint: disable= invalid-name
from
typing
import
Dict
,
Tuple
from
typing
import
Dict
,
Tuple
,
Union
import
torch
import
torch
...
@@ -14,7 +14,10 @@ class SampledSubgraph:
...
@@ -14,7 +14,10 @@ class SampledSubgraph:
@
property
@
property
def
node_pairs
(
def
node_pairs
(
self
,
self
,
)
->
Tuple
[
torch
.
Tensor
]
or
Dict
[(
str
,
str
,
str
),
Tuple
[
torch
.
Tensor
]]:
)
->
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]:
"""Returns the node pairs representing source-destination edges.
"""Returns the node pairs representing source-destination edges.
- If `node_pairs` is a tuple: It should be in the format ('u', 'v')
- If `node_pairs` is a tuple: It should be in the format ('u', 'v')
representing source and destination pairs.
representing source and destination pairs.
...
@@ -26,7 +29,7 @@ class SampledSubgraph:
...
@@ -26,7 +29,7 @@ class SampledSubgraph:
@
property
@
property
def
reverse_column_node_ids
(
def
reverse_column_node_ids
(
self
,
self
,
)
->
torch
.
Tensor
or
Dict
[
str
,
torch
.
Tensor
]:
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]
]
:
"""Returns corresponding reverse column node ids the original graph.
"""Returns corresponding reverse column node ids the original graph.
Column's reverse node ids in the original graph. A graph structure
Column's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
can be treated as a coordinated row and column pair, and this is
...
@@ -42,7 +45,9 @@ class SampledSubgraph:
...
@@ -42,7 +45,9 @@ class SampledSubgraph:
return
None
return
None
@
property
@
property
def
reverse_row_node_ids
(
self
)
->
torch
.
Tensor
or
Dict
[
str
,
torch
.
Tensor
]:
def
reverse_row_node_ids
(
self
,
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns corresponding reverse row node ids the original graph.
"""Returns corresponding reverse row node ids the original graph.
Row's reverse node ids in the original graph. A graph structure
Row's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
can be treated as a coordinated row and column pair, and this is
...
@@ -57,7 +62,7 @@ class SampledSubgraph:
...
@@ -57,7 +62,7 @@ class SampledSubgraph:
return
None
return
None
@
property
@
property
def
reverse_edge_ids
(
self
)
->
torch
.
Tensor
or
Dict
[
str
,
torch
.
Tensor
]:
def
reverse_edge_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]
]
:
"""Returns corresponding reverse edge ids the original graph.
"""Returns corresponding reverse edge ids the original graph.
Reverse edge ids in the original graph. This is useful when edge
Reverse edge ids in the original graph. This is useful when edge
features are needed.
features are needed.
...
...
python/dgl/graphbolt/utils/sample_utils.py
View file @
6b99f328
...
@@ -5,6 +5,8 @@ from typing import Dict, List, Tuple, Union
...
@@ -5,6 +5,8 @@ from typing import Dict, List, Tuple, Union
import
torch
import
torch
from
..base
import
etype_str_to_tuple
def
unique_and_compact
(
def
unique_and_compact
(
nodes
:
Union
[
nodes
:
Union
[
...
@@ -61,7 +63,7 @@ def unique_and_compact(
...
@@ -61,7 +63,7 @@ def unique_and_compact(
def
unique_and_compact_node_pairs
(
def
unique_and_compact_node_pairs
(
node_pairs
:
Union
[
node_pairs
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
,
str
]
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
],
],
unique_dst_nodes
:
Union
[
unique_dst_nodes
:
Union
[
torch
.
Tensor
,
torch
.
Tensor
,
...
@@ -73,8 +75,8 @@ def unique_and_compact_node_pairs(
...
@@ -73,8 +75,8 @@ def unique_and_compact_node_pairs(
Parameters
Parameters
----------
----------
node_pairs : Tuple[torch.Tensor, torch.Tensor]
or
\
node_pairs :
Union[
Tuple[torch.Tensor, torch.Tensor]
,
Dict(
Tuple[str, str, str]
, Tuple[torch.Tensor, torch.Tensor])
Dict(
str
, Tuple[torch.Tensor, torch.Tensor])
]
Node pairs representing source-destination edges.
Node pairs representing source-destination edges.
- If `node_pairs` is a tuple: It means the graph is homogeneous.
- If `node_pairs` is a tuple: It means the graph is homogeneous.
Also, it should be in the format ('u', 'v') representing source
Also, it should be in the format ('u', 'v') representing source
...
@@ -102,20 +104,20 @@ def unique_and_compact_node_pairs(
...
@@ -102,20 +104,20 @@ def unique_and_compact_node_pairs(
>>> import dgl.graphbolt as gb
>>> import dgl.graphbolt as gb
>>> N1 = torch.LongTensor([1, 2, 2])
>>> N1 = torch.LongTensor([1, 2, 2])
>>> N2 = torch.LongTensor([5, 6, 5])
>>> N2 = torch.LongTensor([5, 6, 5])
>>> node_pairs = {
(
"n1
", "e1", "
n2"
)
: (N1, N2),
>>> node_pairs = {"n1
:e1:
n2": (N1, N2),
...
(
"n2
", "e2", "
n1"
)
: (N2, N1)}
... "n2
:e2:
n1": (N2, N1)}
>>> unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
>>> unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
... node_pairs
... node_pairs
... )
... )
>>> print(unique_nodes)
>>> print(unique_nodes)
{'n1': tensor([1, 2]), 'n2': tensor([5, 6])}
{'n1': tensor([1, 2]), 'n2': tensor([5, 6])}
>>> print(compacted_node_pairs)
>>> print(compacted_node_pairs)
{
('n1', 'e1', 'n2')
: (tensor([0, 1, 1]), tensor([0, 1, 0])),
{
"n1:e1:n2"
: (tensor([0, 1, 1]), tensor([0, 1, 0])),
('n2', 'e2', 'n1')
: (tensor([0, 1, 0]), tensor([0, 1, 1]))}
"n2:e2:n1"
: (tensor([0, 1, 0]), tensor([0, 1, 1]))}
"""
"""
is_homogeneous
=
not
isinstance
(
node_pairs
,
dict
)
is_homogeneous
=
not
isinstance
(
node_pairs
,
dict
)
if
is_homogeneous
:
if
is_homogeneous
:
node_pairs
=
{
(
"_N
"
,
"_E"
,
"
_N"
)
:
node_pairs
}
node_pairs
=
{
"_N
:_E:
_N"
:
node_pairs
}
if
unique_dst_nodes
is
not
None
:
if
unique_dst_nodes
is
not
None
:
assert
isinstance
(
assert
isinstance
(
unique_dst_nodes
,
torch
.
Tensor
unique_dst_nodes
,
torch
.
Tensor
...
@@ -126,8 +128,9 @@ def unique_and_compact_node_pairs(
...
@@ -126,8 +128,9 @@ def unique_and_compact_node_pairs(
src_nodes
=
defaultdict
(
list
)
src_nodes
=
defaultdict
(
list
)
dst_nodes
=
defaultdict
(
list
)
dst_nodes
=
defaultdict
(
list
)
for
etype
,
(
src_node
,
dst_node
)
in
node_pairs
.
items
():
for
etype
,
(
src_node
,
dst_node
)
in
node_pairs
.
items
():
src_nodes
[
etype
[
0
]].
append
(
src_node
)
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
dst_nodes
[
etype
[
2
]].
append
(
dst_node
)
src_nodes
[
src_type
].
append
(
src_node
)
dst_nodes
[
dst_type
].
append
(
dst_node
)
src_nodes
=
{
ntype
:
torch
.
cat
(
nodes
)
for
ntype
,
nodes
in
src_nodes
.
items
()}
src_nodes
=
{
ntype
:
torch
.
cat
(
nodes
)
for
ntype
,
nodes
in
src_nodes
.
items
()}
dst_nodes
=
{
ntype
:
torch
.
cat
(
nodes
)
for
ntype
,
nodes
in
dst_nodes
.
items
()}
dst_nodes
=
{
ntype
:
torch
.
cat
(
nodes
)
for
ntype
,
nodes
in
dst_nodes
.
items
()}
# Compute unique destination nodes if not provided.
# Compute unique destination nodes if not provided.
...
@@ -156,7 +159,7 @@ def unique_and_compact_node_pairs(
...
@@ -156,7 +159,7 @@ def unique_and_compact_node_pairs(
# Map back with the same order.
# Map back with the same order.
for
etype
,
pair
in
node_pairs
.
items
():
for
etype
,
pair
in
node_pairs
.
items
():
num_elem
=
pair
[
0
].
size
(
0
)
num_elem
=
pair
[
0
].
size
(
0
)
src_type
,
_
,
dst_type
=
etype
src_type
,
_
,
dst_type
=
etype
_str_to_tuple
(
etype
)
src
=
compacted_src
[
src_type
][:
num_elem
]
src
=
compacted_src
[
src_type
][:
num_elem
]
dst
=
compacted_dst
[
dst_type
][:
num_elem
]
dst
=
compacted_dst
[
dst_type
][:
num_elem
]
compacted_node_pairs
[
etype
]
=
(
src
,
dst
)
compacted_node_pairs
[
etype
]
=
(
src
,
dst
)
...
...
tests/python/pytorch/graphbolt/gb_test_utils.py
View file @
6b99f328
...
@@ -43,7 +43,7 @@ def get_metadata(num_ntypes, num_etypes):
...
@@ -43,7 +43,7 @@ def get_metadata(num_ntypes, num_etypes):
for
n2
in
range
(
n1
,
num_ntypes
):
for
n2
in
range
(
n1
,
num_ntypes
):
if
count
>=
num_etypes
:
if
count
>=
num_etypes
:
break
break
etypes
.
update
({
(
f
"n
{
n1
}
"
,
f
"
e
{
count
}
"
,
f
"
n
{
n2
}
"
)
:
count
})
etypes
.
update
({
f
"n
{
n1
}
:
e
{
count
}
:
n
{
n2
}
"
:
count
})
count
+=
1
count
+=
1
return
gb
.
GraphMetadata
(
ntypes
,
etypes
)
return
gb
.
GraphMetadata
(
ntypes
,
etypes
)
...
...
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
6b99f328
...
@@ -73,7 +73,7 @@ def test_hetero_empty_graph(num_nodes):
...
@@ -73,7 +73,7 @@ def test_hetero_empty_graph(num_nodes):
)
)
def
test_metadata_with_ntype_exception
(
ntypes
):
def
test_metadata_with_ntype_exception
(
ntypes
):
with
pytest
.
raises
(
Exception
):
with
pytest
.
raises
(
Exception
):
gb
.
GraphMetadata
(
ntypes
,
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
1
})
gb
.
GraphMetadata
(
ntypes
,
{
"n1
:e1:
n2"
:
1
})
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
@@ -87,9 +87,9 @@ def test_metadata_with_ntype_exception(ntypes):
...
@@ -87,9 +87,9 @@ def test_metadata_with_ntype_exception(ntypes):
{
"e1"
:
1
},
{
"e1"
:
1
},
{(
"n1"
,
"e1"
):
1
},
{(
"n1"
,
"e1"
):
1
},
{(
"n1"
,
"e1"
,
10
):
1
},
{(
"n1"
,
"e1"
,
10
):
1
},
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
1
,
(
"n1"
,
"e2"
,
"n3"
):
1
},
{
"n1
:e1:
n2"
:
1
,
(
"n1"
,
"e2"
,
"n3"
):
1
},
{(
"n1"
,
"e1"
,
"n10"
):
1
},
{(
"n1"
,
"e1"
,
"n10"
):
1
},
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
1.5
},
{
"n1
:e1:
n2"
:
1.5
},
],
],
)
)
def
test_metadata_with_etype_exception
(
etypes
):
def
test_metadata_with_etype_exception
(
etypes
):
...
@@ -320,10 +320,10 @@ def test_in_subgraph_heterogeneous():
...
@@ -320,10 +320,10 @@ def test_in_subgraph_heterogeneous():
"N1"
:
1
,
"N1"
:
1
,
}
}
etypes
=
{
etypes
=
{
(
"N0
"
,
"R0"
,
"
N0"
)
:
0
,
"N0
:R0:
N0"
:
0
,
(
"N0
"
,
"R1"
,
"
N1"
)
:
1
,
"N0
:R1:
N1"
:
1
,
(
"N1
"
,
"R2"
,
"
N0"
)
:
2
,
"N1
:R2:
N0"
:
2
,
(
"N1
"
,
"R3"
,
"
N1"
)
:
3
,
"N1
:R3:
N1"
:
3
,
}
}
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
...
@@ -403,8 +403,8 @@ def test_sample_neighbors_homo():
...
@@ -403,8 +403,8 @@ def test_sample_neighbors_homo():
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_sample_neighbors_hetero
(
labor
):
def
test_sample_neighbors_hetero
(
labor
):
"""Original graph in COO:
"""Original graph in COO:
(
"n1
", "e1", "
n2"
)
:[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n1
:e1:
n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
(
"n2
", "e2", "
n1"
)
:[0, 0, 1, 2], [0, 1, 1 ,0]
"n2
:e2:
n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 0 1
0 0 1 1 1
0 0 1 1 1
1 1 0 0 0
1 1 0 0 0
...
@@ -413,7 +413,7 @@ def test_sample_neighbors_hetero(labor):
...
@@ -413,7 +413,7 @@ def test_sample_neighbors_hetero(labor):
"""
"""
# Initialize data.
# Initialize data.
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
num_nodes
=
5
num_nodes
=
5
num_edges
=
9
num_edges
=
9
...
@@ -441,11 +441,11 @@ def test_sample_neighbors_hetero(labor):
...
@@ -441,11 +441,11 @@ def test_sample_neighbors_hetero(labor):
# Verify in subgraph.
# Verify in subgraph.
expected_node_pairs
=
{
expected_node_pairs
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
(
"n1
:e1:
n2"
:
(
torch
.
LongTensor
([
0
,
1
]),
torch
.
LongTensor
([
0
,
1
]),
torch
.
LongTensor
([
0
,
0
]),
torch
.
LongTensor
([
0
,
0
]),
),
),
(
"n2
"
,
"e2"
,
"
n1"
)
:
(
"n2
:e2:
n1"
:
(
torch
.
LongTensor
([
0
,
2
]),
torch
.
LongTensor
([
0
,
2
]),
torch
.
LongTensor
([
0
,
0
]),
torch
.
LongTensor
([
0
,
0
]),
),
),
...
@@ -484,8 +484,8 @@ def test_sample_neighbors_fanouts(
...
@@ -484,8 +484,8 @@ def test_sample_neighbors_fanouts(
fanouts
,
expected_sampled_num1
,
expected_sampled_num2
,
labor
fanouts
,
expected_sampled_num1
,
expected_sampled_num2
,
labor
):
):
"""Original graph in COO:
"""Original graph in COO:
(
"n1
", "e1", "
n2"
)
:[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n1
:e1:
n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
(
"n2
", "e2", "
n1"
)
:[0, 0, 1, 2], [0, 1, 1 ,0]
"n2
:e2:
n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 0 1
0 0 1 1 1
0 0 1 1 1
1 1 0 0 0
1 1 0 0 0
...
@@ -494,7 +494,7 @@ def test_sample_neighbors_fanouts(
...
@@ -494,7 +494,7 @@ def test_sample_neighbors_fanouts(
"""
"""
# Initialize data.
# Initialize data.
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
num_nodes
=
5
num_nodes
=
5
num_edges
=
9
num_edges
=
9
...
@@ -520,14 +520,8 @@ def test_sample_neighbors_fanouts(
...
@@ -520,14 +520,8 @@ def test_sample_neighbors_fanouts(
subgraph
=
sampler
(
nodes
,
fanouts
)
subgraph
=
sampler
(
nodes
,
fanouts
)
# Verify in subgraph.
# Verify in subgraph.
assert
(
assert
subgraph
.
node_pairs
[
"n1:e1:n2"
][
0
].
numel
()
==
expected_sampled_num1
subgraph
.
node_pairs
[(
"n1"
,
"e1"
,
"n2"
)][
0
].
numel
()
assert
subgraph
.
node_pairs
[
"n2:e2:n1"
][
0
].
numel
()
==
expected_sampled_num2
==
expected_sampled_num1
)
assert
(
subgraph
.
node_pairs
[(
"n2"
,
"e2"
,
"n1"
)][
0
].
numel
()
==
expected_sampled_num2
)
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
@@ -542,8 +536,8 @@ def test_sample_neighbors_replace(
...
@@ -542,8 +536,8 @@ def test_sample_neighbors_replace(
replace
,
expected_sampled_num1
,
expected_sampled_num2
replace
,
expected_sampled_num1
,
expected_sampled_num2
):
):
"""Original graph in COO:
"""Original graph in COO:
(
"n1
", "e1", "
n2"
)
:[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n1
:e1:
n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
(
"n2
", "e2", "
n1"
)
:[0, 0, 1, 2], [0, 1, 1 ,0]
"n2
:e2:
n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 0 1
0 0 1 1 1
0 0 1 1 1
1 1 0 0 0
1 1 0 0 0
...
@@ -552,7 +546,7 @@ def test_sample_neighbors_replace(
...
@@ -552,7 +546,7 @@ def test_sample_neighbors_replace(
"""
"""
# Initialize data.
# Initialize data.
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
num_nodes
=
5
num_nodes
=
5
num_edges
=
9
num_edges
=
9
...
@@ -578,14 +572,8 @@ def test_sample_neighbors_replace(
...
@@ -578,14 +572,8 @@ def test_sample_neighbors_replace(
)
)
# Verify in subgraph.
# Verify in subgraph.
assert
(
assert
subgraph
.
node_pairs
[
"n1:e1:n2"
][
0
].
numel
()
==
expected_sampled_num1
subgraph
.
node_pairs
[(
"n1"
,
"e1"
,
"n2"
)][
0
].
numel
()
assert
subgraph
.
node_pairs
[
"n2:e2:n1"
][
0
].
numel
()
==
expected_sampled_num2
==
expected_sampled_num1
)
assert
(
subgraph
.
node_pairs
[(
"n2"
,
"e2"
,
"n1"
)][
0
].
numel
()
==
expected_sampled_num2
)
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
@@ -811,7 +799,7 @@ def test_from_dglgraph_homogeneous():
...
@@ -811,7 +799,7 @@ def test_from_dglgraph_homogeneous():
assert
torch
.
equal
(
gb_g
.
node_type_offset
,
torch
.
tensor
([
0
,
1000
]))
assert
torch
.
equal
(
gb_g
.
node_type_offset
,
torch
.
tensor
([
0
,
1000
]))
assert
torch
.
all
(
gb_g
.
type_per_edge
==
0
)
assert
torch
.
all
(
gb_g
.
type_per_edge
==
0
)
assert
gb_g
.
metadata
.
node_type_to_id
==
{
"_N"
:
0
}
assert
gb_g
.
metadata
.
node_type_to_id
==
{
"_N"
:
0
}
assert
gb_g
.
metadata
.
edge_type_to_id
==
{
(
"_N
"
,
"_E"
,
"
_N"
)
:
0
}
assert
gb_g
.
metadata
.
edge_type_to_id
==
{
"_N
:_E:
_N"
:
0
}
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
@@ -855,10 +843,10 @@ def test_from_dglgraph_heterogeneous():
...
@@ -855,10 +843,10 @@ def test_from_dglgraph_heterogeneous():
"n3"
:
2
,
"n3"
:
2
,
}
}
assert
gb_g
.
metadata
.
edge_type_to_id
==
{
assert
gb_g
.
metadata
.
edge_type_to_id
==
{
(
"n1
"
,
"r12"
,
"
n2"
)
:
0
,
"n1
:r12:
n2"
:
0
,
(
"n1
"
,
"r13"
,
"
n3"
)
:
1
,
"n1
:r13:
n3"
:
1
,
(
"n2
"
,
"r21"
,
"
n1"
)
:
2
,
"n2
:r21:
n1"
:
2
,
(
"n2
"
,
"r23"
,
"
n3"
)
:
3
,
"n2
:r23:
n3"
:
3
,
}
}
...
@@ -972,9 +960,9 @@ def test_sample_neighbors_hetero_pick_number(
...
@@ -972,9 +960,9 @@ def test_sample_neighbors_hetero_pick_number(
num_edges
=
9
num_edges
=
9
ntypes
=
{
"N0"
:
0
,
"N1"
:
1
,
"N2"
:
2
,
"N3"
:
3
}
ntypes
=
{
"N0"
:
0
,
"N1"
:
1
,
"N2"
:
2
,
"N3"
:
3
}
etypes
=
{
etypes
=
{
(
"N0
"
,
"R0"
,
"
N1"
)
:
0
,
"N0
:R0:
N1"
:
0
,
(
"N0
"
,
"R1"
,
"
N2"
)
:
1
,
"N0
:R1:
N2"
:
1
,
(
"N0
"
,
"R2"
,
"
N3"
)
:
2
,
"N0
:R2:
N3"
:
2
,
}
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
])
indptr
=
torch
.
LongTensor
([
0
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
])
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
6b99f328
...
@@ -151,7 +151,7 @@ def get_hetero_graph():
...
@@ -151,7 +151,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
...
...
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
View file @
6b99f328
...
@@ -73,7 +73,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
...
@@ -73,7 +73,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
@
pytest
.
mark
.
parametrize
(
"reverse_column"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"reverse_column"
,
[
True
,
False
])
def
test_exclude_edges_hetero
(
reverse_row
,
reverse_column
):
def
test_exclude_edges_hetero
(
reverse_row
,
reverse_column
):
node_pairs
=
{
node_pairs
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
(
"A
:
relation
:
B"
:
(
torch
.
tensor
([
0
,
1
,
2
]),
torch
.
tensor
([
0
,
1
,
2
]),
torch
.
tensor
([
2
,
1
,
0
]),
torch
.
tensor
([
2
,
1
,
0
]),
)
)
...
@@ -94,7 +94,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
...
@@ -94,7 +94,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
else
:
else
:
reverse_column_node_ids
=
None
reverse_column_node_ids
=
None
dst_to_exclude
=
torch
.
tensor
([
0
,
2
])
dst_to_exclude
=
torch
.
tensor
([
0
,
2
])
reverse_edge_ids
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
torch
.
tensor
([
19
,
20
,
21
])}
reverse_edge_ids
=
{
"A
:
relation
:
B"
:
torch
.
tensor
([
19
,
20
,
21
])}
subgraph
=
SampledSubgraphImpl
(
subgraph
=
SampledSubgraphImpl
(
node_pairs
=
node_pairs
,
node_pairs
=
node_pairs
,
reverse_column_node_ids
=
reverse_column_node_ids
,
reverse_column_node_ids
=
reverse_column_node_ids
,
...
@@ -103,14 +103,14 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
...
@@ -103,14 +103,14 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
)
)
edges_to_exclude
=
{
edges_to_exclude
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
(
"A
:
relation
:
B"
:
(
src_to_exclude
,
src_to_exclude
,
dst_to_exclude
,
dst_to_exclude
,
)
)
}
}
result
=
exclude_edges
(
subgraph
,
edges_to_exclude
)
result
=
exclude_edges
(
subgraph
,
edges_to_exclude
)
expected_node_pairs
=
{
expected_node_pairs
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
(
"A
:
relation
:
B"
:
(
torch
.
tensor
([
1
]),
torch
.
tensor
([
1
]),
torch
.
tensor
([
1
]),
torch
.
tensor
([
1
]),
)
)
...
@@ -127,7 +127,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
...
@@ -127,7 +127,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
}
}
else
:
else
:
expected_column_node_ids
=
None
expected_column_node_ids
=
None
expected_edge_ids
=
{
(
"A
"
,
"
relation
"
,
"
B"
)
:
torch
.
tensor
([
20
])}
expected_edge_ids
=
{
"A
:
relation
:
B"
:
torch
.
tensor
([
20
])}
_assert_container_equal
(
result
.
node_pairs
,
expected_node_pairs
)
_assert_container_equal
(
result
.
node_pairs
,
expected_node_pairs
)
_assert_container_equal
(
_assert_container_equal
(
...
...
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
6b99f328
...
@@ -71,7 +71,7 @@ def get_hetero_graph():
...
@@ -71,7 +71,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
...
@@ -120,8 +120,8 @@ def test_FeatureFetcher_with_edges_hetero():
...
@@ -120,8 +120,8 @@ def test_FeatureFetcher_with_edges_hetero():
def
add_node_and_edge_ids
(
seeds
):
def
add_node_and_edge_ids
(
seeds
):
subgraphs
=
[]
subgraphs
=
[]
reverse_edge_ids
=
{
reverse_edge_ids
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n1
:e1:
n2"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
(
"n2
"
,
"e2"
,
"
n1"
)
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n2
:e2:
n1"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
}
}
for
_
in
range
(
3
):
for
_
in
range
(
3
):
subgraphs
.
append
(
subgraphs
.
append
(
...
...
tests/python/pytorch/graphbolt/test_graphbolt_utils.py
View file @
6b99f328
...
@@ -62,15 +62,15 @@ def test_unique_and_compact_node_pairs_hetero():
...
@@ -62,15 +62,15 @@ def test_unique_and_compact_node_pairs_hetero():
"n3"
:
unique_N3
,
"n3"
:
unique_N3
,
}
}
node_pairs
=
{
node_pairs
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
(
"n1
:e1:
n2"
:
(
N1
[:
20
],
N1
[:
20
],
N2
,
N2
,
),
),
(
"n1
"
,
"e2"
,
"
n3"
)
:
(
"n1
:e2:
n3"
:
(
N1
[
20
:
30
],
N1
[
20
:
30
],
N3
,
N3
,
),
),
(
"n2
"
,
"e3"
,
"
n3"
)
:
(
"n2
:e3:
n3"
:
(
N2
[
10
:],
N2
[
10
:],
N3
,
N3
,
),
),
...
@@ -84,7 +84,7 @@ def test_unique_and_compact_node_pairs_hetero():
...
@@ -84,7 +84,7 @@ def test_unique_and_compact_node_pairs_hetero():
assert
torch
.
equal
(
torch
.
sort
(
nodes
)[
0
],
expected_nodes
)
assert
torch
.
equal
(
torch
.
sort
(
nodes
)[
0
],
expected_nodes
)
for
etype
,
pair
in
compacted_node_pairs
.
items
():
for
etype
,
pair
in
compacted_node_pairs
.
items
():
u
,
v
=
pair
u
,
v
=
pair
u_type
,
_
,
v_type
=
etype
u_type
,
_
,
v_type
=
gb
.
etype
_str_to_tuple
(
etype
)
u
,
v
=
unique_nodes
[
u_type
][
u
],
unique_nodes
[
v_type
][
v
]
u
,
v
=
unique_nodes
[
u_type
][
u
],
unique_nodes
[
v_type
][
v
]
expected_u
,
expected_v
=
node_pairs
[
etype
]
expected_u
,
expected_v
=
node_pairs
[
etype
]
assert
torch
.
equal
(
u
,
expected_u
)
assert
torch
.
equal
(
u
,
expected_u
)
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
6b99f328
...
@@ -84,7 +84,7 @@ def get_hetero_graph():
...
@@ -84,7 +84,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
(
"n1
"
,
"e1"
,
"
n2"
)
:
0
,
(
"n2
"
,
"e2"
,
"
n1"
)
:
1
}
etypes
=
{
"n1
:e1:
n2"
:
0
,
"n2
:e2:
n1"
:
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment