Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ed3840fc
Unverified
Commit
ed3840fc
authored
Sep 22, 2023
by
peizhou001
Committed by
GitHub
Sep 22, 2023
Browse files
[Graphbolt] Replace reverse with original (#6371)
parent
7e49ccef
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
168 additions
and
165 deletions
+168
-165
examples/sampling/graphbolt/link_prediction.py
examples/sampling/graphbolt/link_prediction.py
+2
-2
graphbolt/include/graphbolt/sampled_subgraph.h
graphbolt/include/graphbolt/sampled_subgraph.h
+20
-20
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+4
-3
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+3
-3
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+1
-1
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+3
-3
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
+14
-12
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+16
-16
python/dgl/graphbolt/sampled_subgraph.py
python/dgl/graphbolt/sampled_subgraph.py
+36
-36
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+15
-15
tests/python/pytorch/graphbolt/impl/test_minibatch.py
tests/python/pytorch/graphbolt/impl/test_minibatch.py
+29
-29
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
...thon/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
+22
-22
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+3
-3
No files found.
examples/sampling/graphbolt/link_prediction.py
View file @
ed3840fc
...
...
@@ -206,8 +206,8 @@ def to_dgl_blocks(sampled_subgraphs: gb.SampledSubgraphImpl):
blocks
=
[
dgl
.
create_block
(
sampled_subgraph
.
node_pairs
,
num_src_nodes
=
sampled_subgraph
.
reverse
_row_node_ids
.
shape
[
0
],
num_dst_nodes
=
sampled_subgraph
.
reverse
_column_node_ids
.
shape
[
0
],
num_src_nodes
=
sampled_subgraph
.
original
_row_node_ids
.
shape
[
0
],
num_dst_nodes
=
sampled_subgraph
.
original
_column_node_ids
.
shape
[
0
],
)
for
sampled_subgraph
in
sampled_subgraphs
]
...
...
graphbolt/include/graphbolt/sampled_subgraph.h
View file @
ed3840fc
...
...
@@ -22,17 +22,17 @@ namespace sampling {
* ```
* auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64});
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto
reverse
_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
* auto
original
_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
*
* SampledSubgraph sampledSubgraph(indptr, indices,
reverse
_column_node_ids);
* SampledSubgraph sampledSubgraph(indptr, indices,
original
_column_node_ids);
* ```
*
* The `
reverse
_column_node_ids` indicates that nodes `[3, 3, 101]` in the
* The `
original
_column_node_ids` indicates that nodes `[3, 3, 101]` in the
* original graph are mapped to `[0, 1, 2]` in this subgraph, and because
* `
reverse
_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* `
original
_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* the original node ids without compaction.
*
* If `
reverse
_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* If `
original
_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* it would indicate a different mapping for the row nodes. Note this is
* inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* the column while `2` in the row.
...
...
@@ -44,24 +44,24 @@ struct SampledSubgraph : torch::CustomClassHolder {
*
* @param indptr CSC format index pointer array.
* @param indices CSC format index array.
* @param
reverse
_column_node_ids Row's reverse node ids in the original
* @param
original
_column_node_ids Row's reverse node ids in the original
* graph.
* @param
reverse
_row_node_ids Column's reverse node ids in the original
* @param
original
_row_node_ids Column's reverse node ids in the original
* graph.
* @param
reverse
_edge_ids Reverse edge ids in the original graph.
* @param
original
_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge.
*/
SampledSubgraph
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
reverse
_column_node_ids
,
torch
::
optional
<
torch
::
Tensor
>
reverse
_row_node_ids
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
reverse
_edge_ids
=
torch
::
nullopt
,
torch
::
Tensor
original
_column_node_ids
,
torch
::
optional
<
torch
::
Tensor
>
original
_row_node_ids
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
original
_edge_ids
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
=
torch
::
nullopt
)
:
indptr
(
indptr
),
indices
(
indices
),
reverse
_column_node_ids
(
reverse
_column_node_ids
),
reverse
_row_node_ids
(
reverse
_row_node_ids
),
reverse
_edge_ids
(
reverse
_edge_ids
),
original
_column_node_ids
(
original
_column_node_ids
),
original
_row_node_ids
(
original
_row_node_ids
),
original
_edge_ids
(
original
_edge_ids
),
type_per_edge
(
type_per_edge
)
{}
SampledSubgraph
()
=
default
;
...
...
@@ -69,14 +69,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
/**
* @brief CSC format index pointer array, where the implicit node ids are
* already compacted. And the original ids are stored in the
* `
reverse
_column_node_ids` field.
* `
original
_column_node_ids` field.
*/
torch
::
Tensor
indptr
;
/**
* @brief CSC format index array, where the node ids can be compacted ids or
* original ids. If compacted, the original ids are stored in the
* `
reverse
_row_node_ids` field.
* `
original
_row_node_ids` field.
*/
torch
::
Tensor
indices
;
...
...
@@ -88,7 +88,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @note This is required and the mapping relations can be inconsistent with
* column's.
*/
torch
::
Tensor
reverse
_column_node_ids
;
torch
::
Tensor
original
_column_node_ids
;
/**
* @brief Row's reverse node ids in the original graph. A graph structure
...
...
@@ -98,14 +98,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @note This is optional and the mapping relations can be inconsistent with
* row's.
*/
torch
::
optional
<
torch
::
Tensor
>
reverse
_row_node_ids
;
torch
::
optional
<
torch
::
Tensor
>
original
_row_node_ids
;
/**
* @brief Reverse edge ids in the original graph, the edge with id
* `
reverse
_edge_ids[i]` in the original graph is mapped to `i` in this
* `
original
_edge_ids[i]` in the original graph is mapped to `i` in this
* subgraph. This is useful when edge features are needed.
*/
torch
::
optional
<
torch
::
Tensor
>
reverse
_edge_ids
;
torch
::
optional
<
torch
::
Tensor
>
original
_edge_ids
;
/**
* @brief Type id of each edge, where type id is the corresponding index of
...
...
graphbolt/src/python_binding.cc
View file @
ed3840fc
...
...
@@ -17,10 +17,11 @@ TORCH_LIBRARY(graphbolt, m) {
.
def_readwrite
(
"indptr"
,
&
SampledSubgraph
::
indptr
)
.
def_readwrite
(
"indices"
,
&
SampledSubgraph
::
indices
)
.
def_readwrite
(
"
reverse
_row_node_ids"
,
&
SampledSubgraph
::
reverse
_row_node_ids
)
"
original
_row_node_ids"
,
&
SampledSubgraph
::
original
_row_node_ids
)
.
def_readwrite
(
"reverse_column_node_ids"
,
&
SampledSubgraph
::
reverse_column_node_ids
)
.
def_readwrite
(
"reverse_edge_ids"
,
&
SampledSubgraph
::
reverse_edge_ids
)
"original_column_node_ids"
,
&
SampledSubgraph
::
original_column_node_ids
)
.
def_readwrite
(
"original_edge_ids"
,
&
SampledSubgraph
::
original_edge_ids
)
.
def_readwrite
(
"type_per_edge"
,
&
SampledSubgraph
::
type_per_edge
);
m
.
class_
<
CSCSamplingGraph
>
(
"CSCSamplingGraph"
)
.
def
(
"num_nodes"
,
&
CSCSamplingGraph
::
NumNodes
)
...
...
python/dgl/graphbolt/feature_fetcher.py
View file @
ed3840fc
...
...
@@ -95,14 +95,14 @@ class FeatureFetcher(MiniBatchTransformer):
# Read Edge features.
if
self
.
edge_feature_keys
and
data
.
sampled_subgraphs
:
for
i
,
subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
if
subgraph
.
reverse
_edge_ids
is
None
:
if
subgraph
.
original
_edge_ids
is
None
:
continue
if
is_heterogeneous
:
for
(
type_name
,
feature_names
,
)
in
self
.
edge_feature_keys
.
items
():
edges
=
subgraph
.
reverse
_edge_ids
.
get
(
type_name
,
None
)
edges
=
subgraph
.
original
_edge_ids
.
get
(
type_name
,
None
)
if
edges
is
None
:
continue
for
feature_name
in
feature_names
:
...
...
@@ -119,6 +119,6 @@ class FeatureFetcher(MiniBatchTransformer):
"edge"
,
None
,
feature_name
,
subgraph
.
reverse
_edge_ids
,
subgraph
.
original
_edge_ids
,
)
return
data
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
ed3840fc
...
...
@@ -225,7 +225,7 @@ class CSCSamplingGraph:
column_num
=
(
C_sampled_subgraph
.
indptr
[
1
:]
-
C_sampled_subgraph
.
indptr
[:
-
1
]
)
column
=
C_sampled_subgraph
.
reverse
_column_node_ids
.
repeat_interleave
(
column
=
C_sampled_subgraph
.
original
_column_node_ids
.
repeat_interleave
(
column_num
)
row
=
C_sampled_subgraph
.
indices
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
ed3840fc
...
...
@@ -106,14 +106,14 @@ class NeighborSampler(SubgraphSampler):
self
.
replace
,
self
.
prob_name
,
)
reverse
_column_node_ids
=
seeds
original
_column_node_ids
=
seeds
seeds
,
compacted_node_pairs
=
unique_and_compact_node_pairs
(
subgraph
.
node_pairs
,
seeds
)
subgraph
=
SampledSubgraphImpl
(
node_pairs
=
compacted_node_pairs
,
reverse
_column_node_ids
=
reverse
_column_node_ids
,
reverse
_row_node_ids
=
seeds
,
original
_column_node_ids
=
original
_column_node_ids
,
original
_row_node_ids
=
seeds
,
)
subgraphs
.
insert
(
0
,
subgraph
)
return
seeds
,
subgraphs
...
...
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
View file @
ed3840fc
...
...
@@ -17,31 +17,33 @@ class SampledSubgraphImpl(SampledSubgraph):
--------
>>> node_pairs = {"A:relation:B"): (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>>
reverse
_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>>
reverse
_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>>
reverse
_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>>
original
_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>>
original
_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>>
original
_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
...
reverse
_column_node_ids=
reverse
_column_node_ids,
...
reverse
_row_node_ids=
reverse
_row_node_ids,
...
reverse
_edge_ids=
reverse
_edge_ids
...
original
_column_node_ids=
original
_column_node_ids,
...
original
_row_node_ids=
original
_row_node_ids,
...
original
_edge_ids=
original
_edge_ids
... )
>>> print(subgraph.node_pairs)
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.
reverse
_column_node_ids)
>>> print(subgraph.
original
_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(subgraph.
reverse
_row_node_ids)
>>> print(subgraph.
original
_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(subgraph.
reverse
_edge_ids)
>>> print(subgraph.
original
_edge_ids)
{"A:relation:B": tensor([19, 20, 21])}
"""
node_pairs
:
Union
[
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
]
=
None
reverse_column_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
reverse_row_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
reverse_edge_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
original_column_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
original_row_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
original_edge_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
if
isinstance
(
self
.
node_pairs
,
dict
):
...
...
python/dgl/graphbolt/minibatch.py
View file @
ed3840fc
...
...
@@ -140,14 +140,14 @@ class MiniBatch:
blocks
=
[]
for
subgraph
in
self
.
sampled_subgraphs
:
reverse
_row_node_ids
=
subgraph
.
reverse
_row_node_ids
original
_row_node_ids
=
subgraph
.
original
_row_node_ids
assert
(
reverse
_row_node_ids
is
not
None
),
"Missing `
reverse
_row_node_ids` in sampled subgraph."
reverse
_column_node_ids
=
subgraph
.
reverse
_column_node_ids
original
_row_node_ids
is
not
None
),
"Missing `
original
_row_node_ids` in sampled subgraph."
original
_column_node_ids
=
subgraph
.
original
_column_node_ids
assert
(
reverse
_column_node_ids
is
not
None
),
"Missing `
reverse
_column_node_ids` in sampled subgraph."
original
_column_node_ids
is
not
None
),
"Missing `
original
_column_node_ids` in sampled subgraph."
if
is_heterogeneous
:
node_pairs
=
{
etype_str_to_tuple
(
etype
):
v
...
...
@@ -155,16 +155,16 @@ class MiniBatch:
}
num_src_nodes
=
{
ntype
:
nodes
.
size
(
0
)
for
ntype
,
nodes
in
reverse
_row_node_ids
.
items
()
for
ntype
,
nodes
in
original
_row_node_ids
.
items
()
}
num_dst_nodes
=
{
ntype
:
nodes
.
size
(
0
)
for
ntype
,
nodes
in
reverse
_column_node_ids
.
items
()
for
ntype
,
nodes
in
original
_column_node_ids
.
items
()
}
else
:
node_pairs
=
subgraph
.
node_pairs
num_src_nodes
=
reverse
_row_node_ids
.
size
(
0
)
num_dst_nodes
=
reverse
_column_node_ids
.
size
(
0
)
num_src_nodes
=
original
_row_node_ids
.
size
(
0
)
num_dst_nodes
=
original
_column_node_ids
.
size
(
0
)
blocks
.
append
(
dgl
.
create_block
(
node_pairs
,
...
...
@@ -194,15 +194,15 @@ class MiniBatch:
# Assign reverse node ids to the outermost layer's source nodes.
for
node_type
,
reverse_ids
in
self
.
sampled_subgraphs
[
0
].
reverse
_row_node_ids
.
items
():
].
original
_row_node_ids
.
items
():
blocks
[
0
].
srcnodes
[
node_type
].
data
[
dgl
.
NID
]
=
reverse_ids
# Assign reverse edges ids.
for
block
,
subgraph
in
zip
(
blocks
,
self
.
sampled_subgraphs
):
if
subgraph
.
reverse
_edge_ids
:
if
subgraph
.
original
_edge_ids
:
for
(
edge_type
,
reverse_ids
,
)
in
subgraph
.
reverse
_edge_ids
.
items
():
)
in
subgraph
.
original
_edge_ids
.
items
():
block
.
edges
[
etype_str_to_tuple
(
edge_type
)].
data
[
dgl
.
EID
]
=
reverse_ids
...
...
@@ -218,11 +218,11 @@ class MiniBatch:
block
.
edata
[
feature_name
]
=
feature
blocks
[
0
].
srcdata
[
dgl
.
NID
]
=
self
.
sampled_subgraphs
[
0
].
reverse
_row_node_ids
].
original
_row_node_ids
# Assign reverse edges ids.
for
block
,
subgraph
in
zip
(
blocks
,
self
.
sampled_subgraphs
):
if
subgraph
.
reverse
_edge_ids
is
not
None
:
block
.
edata
[
dgl
.
EID
]
=
subgraph
.
reverse
_edge_ids
if
subgraph
.
original
_edge_ids
is
not
None
:
block
.
edata
[
dgl
.
EID
]
=
subgraph
.
original
_edge_ids
return
blocks
...
...
python/dgl/graphbolt/sampled_subgraph.py
View file @
ed3840fc
...
...
@@ -29,16 +29,16 @@ class SampledSubgraph:
raise
NotImplementedError
@
property
def
reverse
_column_node_ids
(
def
original
_column_node_ids
(
self
,
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns corresponding reverse column node ids the original graph.
Column's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
the mapped ids of the column.
- If `
reverse
_column_node_ids` is a tensor: It represents the
- If `
original
_column_node_ids` is a tensor: It represents the
original node ids.
- If `
reverse
_column_node_ids` is a dictionary: The keys should be
- If `
original
_column_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
heterogeneous node ids.
If present, it means column IDs are compacted, and `node_pairs`
...
...
@@ -47,16 +47,16 @@ class SampledSubgraph:
return
None
@
property
def
reverse
_row_node_ids
(
def
original
_row_node_ids
(
self
,
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns corresponding reverse row node ids the original graph.
Row's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
the mapped ids of the row.
- If `
reverse
_row_node_ids` is a tensor: It represents the
- If `
original
_row_node_ids` is a tensor: It represents the
original node ids.
- If `
reverse
_row_node_ids` is a dictionary: The keys should be
- If `
original
_row_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
heterogeneous node ids.
If present, it means row IDs are compacted, and `node_pairs`
...
...
@@ -64,13 +64,13 @@ class SampledSubgraph:
return
None
@
property
def
reverse
_edge_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
def
original
_edge_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns corresponding reverse edge ids the original graph.
Reverse edge ids in the original graph. This is useful when edge
features are needed.
- If `
reverse
_edge_ids` is a tensor: It represents the
- If `
original
_edge_ids` is a tensor: It represents the
original edge ids.
- If `
reverse
_edge_ids` is a dictionary: The keys should be
- If `
original
_edge_ids` is a dictionary: The keys should be
edge type and the values should be corresponding original
heterogeneous edge ids.
"""
...
...
@@ -110,24 +110,24 @@ class SampledSubgraph:
--------
>>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>>
reverse
_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>>
reverse
_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>>
reverse
_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>>
original
_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>>
original
_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>>
original
_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
...
reverse
_column_node_ids=
reverse
_column_node_ids,
...
reverse
_row_node_ids=
reverse
_row_node_ids,
...
reverse
_edge_ids=
reverse
_edge_ids
...
original
_column_node_ids=
original
_column_node_ids,
...
original
_row_node_ids=
original
_row_node_ids,
...
original
_edge_ids=
original
_edge_ids
... )
>>> edges_to_exclude = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.node_pairs)
{"A:relation:B": (tensor([0]), tensor([0]))}
>>> print(result.
reverse
_column_node_ids)
>>> print(result.
original
_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(result.
reverse
_row_node_ids)
>>> print(result.
original
_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(result.
reverse
_edge_ids)
>>> print(result.
original
_edge_ids)
{"A:relation:B": tensor([19])}
"""
assert
isinstance
(
self
.
node_pairs
,
tuple
)
==
isinstance
(
edges
,
tuple
),
(
...
...
@@ -144,8 +144,8 @@ class SampledSubgraph:
if
isinstance
(
self
.
node_pairs
,
tuple
):
reverse_edges
=
_to_reverse_ids
(
self
.
node_pairs
,
self
.
reverse
_row_node_ids
,
self
.
reverse
_column_node_ids
,
self
.
original
_row_node_ids
,
self
.
original
_column_node_ids
,
)
index
=
_exclude_homo_edges
(
reverse_edges
,
edges
)
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
...
...
@@ -153,20 +153,20 @@ class SampledSubgraph:
index
=
{}
for
etype
,
pair
in
self
.
node_pairs
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
reverse
_row_node_ids
=
(
original
_row_node_ids
=
(
None
if
self
.
reverse
_row_node_ids
is
None
else
self
.
reverse
_row_node_ids
.
get
(
src_type
)
if
self
.
original
_row_node_ids
is
None
else
self
.
original
_row_node_ids
.
get
(
src_type
)
)
reverse
_column_node_ids
=
(
original
_column_node_ids
=
(
None
if
self
.
reverse
_column_node_ids
is
None
else
self
.
reverse
_column_node_ids
.
get
(
dst_type
)
if
self
.
original
_column_node_ids
is
None
else
self
.
original
_column_node_ids
.
get
(
dst_type
)
)
reverse_edges
=
_to_reverse_ids
(
pair
,
reverse
_row_node_ids
,
reverse
_column_node_ids
,
original
_row_node_ids
,
original
_column_node_ids
,
)
index
[
etype
]
=
_exclude_homo_edges
(
reverse_edges
,
edges
.
get
(
etype
)
...
...
@@ -174,12 +174,12 @@ class SampledSubgraph:
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
def
_to_reverse_ids
(
node_pair
,
reverse
_row_node_ids
,
reverse
_column_node_ids
):
def
_to_reverse_ids
(
node_pair
,
original
_row_node_ids
,
original
_column_node_ids
):
u
,
v
=
node_pair
if
reverse
_row_node_ids
is
not
None
:
u
=
reverse
_row_node_ids
[
u
]
if
reverse
_column_node_ids
is
not
None
:
v
=
reverse
_column_node_ids
[
v
]
if
original
_row_node_ids
is
not
None
:
u
=
original
_row_node_ids
[
u
]
if
original
_column_node_ids
is
not
None
:
v
=
original
_column_node_ids
[
v
]
return
(
u
,
v
)
...
...
@@ -224,7 +224,7 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
return
(
_index_select
(
subgraph
.
node_pairs
,
index
),
subgraph
.
reverse
_column_node_ids
,
subgraph
.
reverse
_row_node_ids
,
_index_select
(
subgraph
.
reverse
_edge_ids
,
index
),
subgraph
.
original
_column_node_ids
,
subgraph
.
original
_row_node_ids
,
_index_select
(
subgraph
.
original
_edge_ids
,
index
),
)
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
ed3840fc
...
...
@@ -397,12 +397,12 @@ def test_in_subgraph_homogeneous():
assert
torch
.
equal
(
in_subgraph
.
indices
,
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
)
assert
torch
.
equal
(
in_subgraph
.
reverse
_column_node_ids
,
nodes
)
assert
torch
.
equal
(
in_subgraph
.
original
_column_node_ids
,
nodes
)
assert
torch
.
equal
(
in_subgraph
.
reverse
_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
in_subgraph
.
original
_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
)
assert
torch
.
equal
(
in_subgraph
.
reverse
_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
in_subgraph
.
original
_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
)
assert
in_subgraph
.
type_per_edge
is
None
...
...
@@ -463,12 +463,12 @@ def test_in_subgraph_heterogeneous():
assert
torch
.
equal
(
in_subgraph
.
indices
,
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
)
assert
torch
.
equal
(
in_subgraph
.
reverse
_column_node_ids
,
nodes
)
assert
torch
.
equal
(
in_subgraph
.
original
_column_node_ids
,
nodes
)
assert
torch
.
equal
(
in_subgraph
.
reverse
_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
in_subgraph
.
original
_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
)
assert
torch
.
equal
(
in_subgraph
.
reverse
_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
in_subgraph
.
original
_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
)
assert
torch
.
equal
(
in_subgraph
.
type_per_edge
,
torch
.
LongTensor
([
2
,
2
,
1
,
3
,
1
,
3
,
3
])
...
...
@@ -505,9 +505,9 @@ def test_sample_neighbors_homo():
# Verify in subgraph.
sampled_num
=
subgraph
.
node_pairs
[
0
].
size
(
0
)
assert
sampled_num
==
6
assert
subgraph
.
reverse
_column_node_ids
is
None
assert
subgraph
.
reverse
_row_node_ids
is
None
assert
subgraph
.
reverse
_edge_ids
is
None
assert
subgraph
.
original
_column_node_ids
is
None
assert
subgraph
.
original
_row_node_ids
is
None
assert
subgraph
.
original
_edge_ids
is
None
@
unittest
.
skipIf
(
...
...
@@ -568,9 +568,9 @@ def test_sample_neighbors_hetero(labor):
for
etype
,
pairs
in
expected_node_pairs
.
items
():
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
0
],
pairs
[
0
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
1
],
pairs
[
1
])
assert
subgraph
.
reverse
_column_node_ids
is
None
assert
subgraph
.
reverse
_row_node_ids
is
None
assert
subgraph
.
reverse
_edge_ids
is
None
assert
subgraph
.
original
_column_node_ids
is
None
assert
subgraph
.
original
_row_node_ids
is
None
assert
subgraph
.
original
_edge_ids
is
None
# Sample on single node type.
nodes
=
{
"n1"
:
torch
.
LongTensor
([
0
])}
...
...
@@ -593,9 +593,9 @@ def test_sample_neighbors_hetero(labor):
for
etype
,
pairs
in
expected_node_pairs
.
items
():
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
0
],
pairs
[
0
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
1
],
pairs
[
1
])
assert
subgraph
.
reverse
_column_node_ids
is
None
assert
subgraph
.
reverse
_row_node_ids
is
None
assert
subgraph
.
reverse
_edge_ids
is
None
assert
subgraph
.
original
_column_node_ids
is
None
assert
subgraph
.
original
_row_node_ids
is
None
assert
subgraph
.
original
_edge_ids
is
None
@
unittest
.
skipIf
(
...
...
tests/python/pytorch/graphbolt/impl/test_minibatch.py
View file @
ed3840fc
...
...
@@ -13,11 +13,11 @@ def test_to_dgl_blocks_hetero():
},
{
relation
:
(
torch
.
tensor
([
0
,
1
]),
torch
.
tensor
([
1
,
0
]))},
]
reverse
_column_node_ids
=
[
original
_column_node_ids
=
[
{
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
])},
{
"B"
:
torch
.
tensor
([
10
,
11
])},
]
reverse
_row_node_ids
=
[
original
_row_node_ids
=
[
{
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
]),
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
...
...
@@ -27,7 +27,7 @@ def test_to_dgl_blocks_hetero():
"B"
:
torch
.
tensor
([
10
,
11
]),
},
]
reverse
_edge_ids
=
[
original
_edge_ids
=
[
{
relation
:
torch
.
tensor
([
19
,
20
,
21
]),
reverse_relation
:
torch
.
tensor
([
23
,
26
]),
...
...
@@ -46,9 +46,9 @@ def test_to_dgl_blocks_hetero():
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
reverse
_column_node_ids
=
reverse
_column_node_ids
[
i
],
reverse
_row_node_ids
=
reverse
_row_node_ids
[
i
],
reverse
_edge_ids
=
reverse
_edge_ids
[
i
],
original
_column_node_ids
=
original
_column_node_ids
[
i
],
original
_row_node_ids
=
original
_row_node_ids
[
i
],
original
_edge_ids
=
original
_edge_ids
[
i
],
)
)
blocks
=
gb
.
MiniBatch
(
...
...
@@ -63,7 +63,7 @@ def test_to_dgl_blocks_hetero():
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
i
][
relation
][
0
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
i
][
relation
][
1
])
assert
torch
.
equal
(
block
.
edges
[
etype
].
data
[
dgl
.
EID
],
reverse
_edge_ids
[
i
][
relation
]
block
.
edges
[
etype
].
data
[
dgl
.
EID
],
original
_edge_ids
[
i
][
relation
]
)
assert
torch
.
equal
(
block
.
edges
[
etype
].
data
[
"x"
],
...
...
@@ -73,10 +73,10 @@ def test_to_dgl_blocks_hetero():
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
0
][
reverse_relation
][
0
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
0
][
reverse_relation
][
1
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"A"
],
reverse
_row_node_ids
[
0
][
"A"
]
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"A"
],
original
_row_node_ids
[
0
][
"A"
]
)
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"B"
],
reverse
_row_node_ids
[
0
][
"B"
]
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"B"
],
original
_row_node_ids
[
0
][
"B"
]
)
assert
torch
.
equal
(
blocks
[
0
].
srcnodes
[
"A"
].
data
[
"x"
],
node_features
[(
"A"
,
"x"
)]
...
...
@@ -97,15 +97,15 @@ def test_to_dgl_blocks_homo():
torch
.
tensor
([
1
,
0
,
0
]),
),
]
reverse
_column_node_ids
=
[
original
_column_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
]),
]
reverse
_row_node_ids
=
[
original
_row_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
,
12
]),
]
reverse
_edge_ids
=
[
original
_edge_ids
=
[
torch
.
tensor
([
19
,
20
,
21
,
22
,
25
,
30
]),
torch
.
tensor
([
10
,
15
,
17
]),
]
...
...
@@ -119,9 +119,9 @@ def test_to_dgl_blocks_homo():
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
reverse
_column_node_ids
=
reverse
_column_node_ids
[
i
],
reverse
_row_node_ids
=
reverse
_row_node_ids
[
i
],
reverse
_edge_ids
=
reverse
_edge_ids
[
i
],
original
_column_node_ids
=
original
_column_node_ids
[
i
],
original
_row_node_ids
=
original
_row_node_ids
[
i
],
original
_edge_ids
=
original
_edge_ids
[
i
],
)
)
blocks
=
gb
.
MiniBatch
(
...
...
@@ -133,9 +133,9 @@ def test_to_dgl_blocks_homo():
for
i
,
block
in
enumerate
(
blocks
):
assert
torch
.
equal
(
block
.
edges
()[
0
],
node_pairs
[
i
][
0
])
assert
torch
.
equal
(
block
.
edges
()[
1
],
node_pairs
[
i
][
1
])
assert
torch
.
equal
(
block
.
edata
[
dgl
.
EID
],
reverse
_edge_ids
[
i
])
assert
torch
.
equal
(
block
.
edata
[
dgl
.
EID
],
original
_edge_ids
[
i
])
assert
torch
.
equal
(
block
.
edata
[
"x"
],
edge_features
[
i
][
"x"
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
reverse
_row_node_ids
[
0
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
original
_row_node_ids
[
0
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
"x"
],
node_features
[
"x"
])
...
...
@@ -150,15 +150,15 @@ def test_representation():
torch
.
tensor
([
1
,
0
,
0
]),
),
]
reverse
_column_node_ids
=
[
original
_column_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
]),
]
reverse
_row_node_ids
=
[
original
_row_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
,
12
]),
]
reverse
_edge_ids
=
[
original
_edge_ids
=
[
torch
.
tensor
([
19
,
20
,
21
,
22
,
25
,
30
]),
torch
.
tensor
([
10
,
15
,
17
]),
]
...
...
@@ -172,9 +172,9 @@ def test_representation():
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
reverse
_column_node_ids
=
reverse
_column_node_ids
[
i
],
reverse
_row_node_ids
=
reverse
_row_node_ids
[
i
],
reverse
_edge_ids
=
reverse
_edge_ids
[
i
],
original
_column_node_ids
=
original
_column_node_ids
[
i
],
original
_row_node_ids
=
original
_row_node_ids
[
i
],
original
_edge_ids
=
original
_edge_ids
[
i
],
)
)
negative_srcs
=
torch
.
tensor
([[
8
],
[
1
],
[
6
]])
...
...
@@ -220,13 +220,13 @@ def test_representation():
expect_result
=
str
(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
reverse
_column_node_ids=tensor([10, 11, 12, 13]),
reverse
_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
reverse
_row_node_ids=tensor([10, 11, 12, 13]),),
original
_column_node_ids=tensor([10, 11, 12, 13]),
original
_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original
_row_node_ids=tensor([10, 11, 12, 13]),),
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
reverse
_column_node_ids=tensor([10, 11]),
reverse
_edge_ids=tensor([10, 15, 17]),
reverse
_row_node_ids=tensor([10, 11, 12]),)],
original
_column_node_ids=tensor([10, 11]),
original
_edge_ids=tensor([10, 15, 17]),
original
_row_node_ids=tensor([10, 11, 12]),)],
node_pairs=[(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
(tensor([0, 1, 2]), tensor([1, 0, 0]))],
node_features={'x': tensor([7, 6, 2, 2])},
...
...
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
View file @
ed3840fc
...
...
@@ -26,24 +26,24 @@ def _assert_container_equal(lhs, rhs):
def
test_exclude_edges_homo
(
reverse_row
,
reverse_column
):
node_pairs
=
(
torch
.
tensor
([
0
,
2
,
3
]),
torch
.
tensor
([
1
,
4
,
2
]))
if
reverse_row
:
reverse
_row_node_ids
=
torch
.
tensor
([
10
,
15
,
11
,
24
,
9
])
original
_row_node_ids
=
torch
.
tensor
([
10
,
15
,
11
,
24
,
9
])
src_to_exclude
=
torch
.
tensor
([
11
])
else
:
reverse
_row_node_ids
=
None
original
_row_node_ids
=
None
src_to_exclude
=
torch
.
tensor
([
2
])
if
reverse_column
:
reverse
_column_node_ids
=
torch
.
tensor
([
10
,
15
,
11
,
24
,
9
])
original
_column_node_ids
=
torch
.
tensor
([
10
,
15
,
11
,
24
,
9
])
dst_to_exclude
=
torch
.
tensor
([
9
])
else
:
reverse
_column_node_ids
=
None
original
_column_node_ids
=
None
dst_to_exclude
=
torch
.
tensor
([
4
])
reverse
_edge_ids
=
torch
.
Tensor
([
5
,
9
,
10
])
original
_edge_ids
=
torch
.
Tensor
([
5
,
9
,
10
])
subgraph
=
SampledSubgraphImpl
(
node_pairs
,
reverse
_column_node_ids
,
reverse
_row_node_ids
,
reverse
_edge_ids
,
original
_column_node_ids
,
original
_row_node_ids
,
original
_edge_ids
,
)
edges_to_exclude
=
(
src_to_exclude
,
dst_to_exclude
)
result
=
subgraph
.
exclude_edges
(
edges_to_exclude
)
...
...
@@ -60,10 +60,10 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
_assert_container_equal
(
result
.
node_pairs
,
expected_node_pairs
)
_assert_container_equal
(
result
.
reverse
_column_node_ids
,
expected_column_node_ids
result
.
original
_column_node_ids
,
expected_column_node_ids
)
_assert_container_equal
(
result
.
reverse
_row_node_ids
,
expected_row_node_ids
)
_assert_container_equal
(
result
.
reverse
_edge_ids
,
expected_edge_ids
)
_assert_container_equal
(
result
.
original
_row_node_ids
,
expected_row_node_ids
)
_assert_container_equal
(
result
.
original
_edge_ids
,
expected_edge_ids
)
@
pytest
.
mark
.
parametrize
(
"reverse_row"
,
[
True
,
False
])
...
...
@@ -76,27 +76,27 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
)
}
if
reverse_row
:
reverse
_row_node_ids
=
{
original
_row_node_ids
=
{
"A"
:
torch
.
tensor
([
13
,
14
,
15
]),
}
src_to_exclude
=
torch
.
tensor
([
15
,
13
])
else
:
reverse
_row_node_ids
=
None
original
_row_node_ids
=
None
src_to_exclude
=
torch
.
tensor
([
2
,
0
])
if
reverse_column
:
reverse
_column_node_ids
=
{
original
_column_node_ids
=
{
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
}
dst_to_exclude
=
torch
.
tensor
([
10
,
12
])
else
:
reverse
_column_node_ids
=
None
original
_column_node_ids
=
None
dst_to_exclude
=
torch
.
tensor
([
0
,
2
])
reverse
_edge_ids
=
{
"A:relation:B"
:
torch
.
tensor
([
19
,
20
,
21
])}
original
_edge_ids
=
{
"A:relation:B"
:
torch
.
tensor
([
19
,
20
,
21
])}
subgraph
=
SampledSubgraphImpl
(
node_pairs
=
node_pairs
,
reverse
_column_node_ids
=
reverse
_column_node_ids
,
reverse
_row_node_ids
=
reverse
_row_node_ids
,
reverse
_edge_ids
=
reverse
_edge_ids
,
original
_column_node_ids
=
original
_column_node_ids
,
original
_row_node_ids
=
original
_row_node_ids
,
original
_edge_ids
=
original
_edge_ids
,
)
edges_to_exclude
=
{
...
...
@@ -128,7 +128,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
_assert_container_equal
(
result
.
node_pairs
,
expected_node_pairs
)
_assert_container_equal
(
result
.
reverse
_column_node_ids
,
expected_column_node_ids
result
.
original
_column_node_ids
,
expected_column_node_ids
)
_assert_container_equal
(
result
.
reverse
_row_node_ids
,
expected_row_node_ids
)
_assert_container_equal
(
result
.
reverse
_edge_ids
,
expected_edge_ids
)
_assert_container_equal
(
result
.
original
_row_node_ids
,
expected_row_node_ids
)
_assert_container_equal
(
result
.
original
_edge_ids
,
expected_edge_ids
)
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
ed3840fc
...
...
@@ -65,7 +65,7 @@ def test_FeatureFetcher_with_edges_homo():
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
reverse
_edge_ids
=
torch
.
randint
(
0
,
graph
.
num_edges
,
(
10
,)),
original
_edge_ids
=
torch
.
randint
(
0
,
graph
.
num_edges
,
(
10
,)),
)
)
data
=
gb
.
MiniBatch
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
...
...
@@ -146,7 +146,7 @@ def test_FeatureFetcher_with_edges_hetero():
def
add_node_and_edge_ids
(
seeds
):
subgraphs
=
[]
reverse
_edge_ids
=
{
original
_edge_ids
=
{
"n1:e1:n2"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n2:e2:n1"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
}
...
...
@@ -154,7 +154,7 @@ def test_FeatureFetcher_with_edges_hetero():
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
reverse
_edge_ids
=
reverse
_edge_ids
,
original
_edge_ids
=
original
_edge_ids
,
)
)
data
=
gb
.
MiniBatch
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment