Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ed3840fc
Unverified
Commit
ed3840fc
authored
Sep 22, 2023
by
peizhou001
Committed by
GitHub
Sep 22, 2023
Browse files
[Graphbolt] Replace reverse with original (#6371)
parent
7e49ccef
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
168 additions
and
165 deletions
+168
-165
examples/sampling/graphbolt/link_prediction.py
examples/sampling/graphbolt/link_prediction.py
+2
-2
graphbolt/include/graphbolt/sampled_subgraph.h
graphbolt/include/graphbolt/sampled_subgraph.h
+20
-20
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+4
-3
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+3
-3
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+1
-1
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+3
-3
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
+14
-12
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+16
-16
python/dgl/graphbolt/sampled_subgraph.py
python/dgl/graphbolt/sampled_subgraph.py
+36
-36
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+15
-15
tests/python/pytorch/graphbolt/impl/test_minibatch.py
tests/python/pytorch/graphbolt/impl/test_minibatch.py
+29
-29
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
...thon/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
+22
-22
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+3
-3
No files found.
examples/sampling/graphbolt/link_prediction.py
View file @
ed3840fc
...
@@ -206,8 +206,8 @@ def to_dgl_blocks(sampled_subgraphs: gb.SampledSubgraphImpl):
...
@@ -206,8 +206,8 @@ def to_dgl_blocks(sampled_subgraphs: gb.SampledSubgraphImpl):
blocks
=
[
blocks
=
[
dgl
.
create_block
(
dgl
.
create_block
(
sampled_subgraph
.
node_pairs
,
sampled_subgraph
.
node_pairs
,
num_src_nodes
=
sampled_subgraph
.
reverse
_row_node_ids
.
shape
[
0
],
num_src_nodes
=
sampled_subgraph
.
original
_row_node_ids
.
shape
[
0
],
num_dst_nodes
=
sampled_subgraph
.
reverse
_column_node_ids
.
shape
[
0
],
num_dst_nodes
=
sampled_subgraph
.
original
_column_node_ids
.
shape
[
0
],
)
)
for
sampled_subgraph
in
sampled_subgraphs
for
sampled_subgraph
in
sampled_subgraphs
]
]
...
...
graphbolt/include/graphbolt/sampled_subgraph.h
View file @
ed3840fc
...
@@ -22,17 +22,17 @@ namespace sampling {
...
@@ -22,17 +22,17 @@ namespace sampling {
* ```
* ```
* auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64});
* auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64});
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto
reverse
_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
* auto
original
_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
*
*
* SampledSubgraph sampledSubgraph(indptr, indices,
reverse
_column_node_ids);
* SampledSubgraph sampledSubgraph(indptr, indices,
original
_column_node_ids);
* ```
* ```
*
*
* The `
reverse
_column_node_ids` indicates that nodes `[3, 3, 101]` in the
* The `
original
_column_node_ids` indicates that nodes `[3, 3, 101]` in the
* original graph are mapped to `[0, 1, 2]` in this subgraph, and because
* original graph are mapped to `[0, 1, 2]` in this subgraph, and because
* `
reverse
_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* `
original
_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* the original node ids without compaction.
* the original node ids without compaction.
*
*
* If `
reverse
_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* If `
original
_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* it would indicate a different mapping for the row nodes. Note this is
* it would indicate a different mapping for the row nodes. Note this is
* inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* the column while `2` in the row.
* the column while `2` in the row.
...
@@ -44,24 +44,24 @@ struct SampledSubgraph : torch::CustomClassHolder {
...
@@ -44,24 +44,24 @@ struct SampledSubgraph : torch::CustomClassHolder {
*
*
* @param indptr CSC format index pointer array.
* @param indptr CSC format index pointer array.
* @param indices CSC format index array.
* @param indices CSC format index array.
* @param
reverse
_column_node_ids Row's reverse node ids in the original
* @param
original
_column_node_ids Row's reverse node ids in the original
* graph.
* graph.
* @param
reverse
_row_node_ids Column's reverse node ids in the original
* @param
original
_row_node_ids Column's reverse node ids in the original
* graph.
* graph.
* @param
reverse
_edge_ids Reverse edge ids in the original graph.
* @param
original
_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge.
* @param type_per_edge Type id of each edge.
*/
*/
SampledSubgraph
(
SampledSubgraph
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
reverse
_column_node_ids
,
torch
::
Tensor
original
_column_node_ids
,
torch
::
optional
<
torch
::
Tensor
>
reverse
_row_node_ids
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
original
_row_node_ids
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
reverse
_edge_ids
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
original
_edge_ids
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
=
torch
::
nullopt
)
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
=
torch
::
nullopt
)
:
indptr
(
indptr
),
:
indptr
(
indptr
),
indices
(
indices
),
indices
(
indices
),
reverse
_column_node_ids
(
reverse
_column_node_ids
),
original
_column_node_ids
(
original
_column_node_ids
),
reverse
_row_node_ids
(
reverse
_row_node_ids
),
original
_row_node_ids
(
original
_row_node_ids
),
reverse
_edge_ids
(
reverse
_edge_ids
),
original
_edge_ids
(
original
_edge_ids
),
type_per_edge
(
type_per_edge
)
{}
type_per_edge
(
type_per_edge
)
{}
SampledSubgraph
()
=
default
;
SampledSubgraph
()
=
default
;
...
@@ -69,14 +69,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
...
@@ -69,14 +69,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
/**
/**
* @brief CSC format index pointer array, where the implicit node ids are
* @brief CSC format index pointer array, where the implicit node ids are
* already compacted. And the original ids are stored in the
* already compacted. And the original ids are stored in the
* `
reverse
_column_node_ids` field.
* `
original
_column_node_ids` field.
*/
*/
torch
::
Tensor
indptr
;
torch
::
Tensor
indptr
;
/**
/**
* @brief CSC format index array, where the node ids can be compacted ids or
* @brief CSC format index array, where the node ids can be compacted ids or
* original ids. If compacted, the original ids are stored in the
* original ids. If compacted, the original ids are stored in the
* `
reverse
_row_node_ids` field.
* `
original
_row_node_ids` field.
*/
*/
torch
::
Tensor
indices
;
torch
::
Tensor
indices
;
...
@@ -88,7 +88,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
...
@@ -88,7 +88,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @note This is required and the mapping relations can be inconsistent with
* @note This is required and the mapping relations can be inconsistent with
* column's.
* column's.
*/
*/
torch
::
Tensor
reverse
_column_node_ids
;
torch
::
Tensor
original
_column_node_ids
;
/**
/**
* @brief Row's reverse node ids in the original graph. A graph structure
* @brief Row's reverse node ids in the original graph. A graph structure
...
@@ -98,14 +98,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
...
@@ -98,14 +98,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @note This is optional and the mapping relations can be inconsistent with
* @note This is optional and the mapping relations can be inconsistent with
* row's.
* row's.
*/
*/
torch
::
optional
<
torch
::
Tensor
>
reverse
_row_node_ids
;
torch
::
optional
<
torch
::
Tensor
>
original
_row_node_ids
;
/**
/**
* @brief Reverse edge ids in the original graph, the edge with id
* @brief Reverse edge ids in the original graph, the edge with id
* `
reverse
_edge_ids[i]` in the original graph is mapped to `i` in this
* `
original
_edge_ids[i]` in the original graph is mapped to `i` in this
* subgraph. This is useful when edge features are needed.
* subgraph. This is useful when edge features are needed.
*/
*/
torch
::
optional
<
torch
::
Tensor
>
reverse
_edge_ids
;
torch
::
optional
<
torch
::
Tensor
>
original
_edge_ids
;
/**
/**
* @brief Type id of each edge, where type id is the corresponding index of
* @brief Type id of each edge, where type id is the corresponding index of
...
...
graphbolt/src/python_binding.cc
View file @
ed3840fc
...
@@ -17,10 +17,11 @@ TORCH_LIBRARY(graphbolt, m) {
...
@@ -17,10 +17,11 @@ TORCH_LIBRARY(graphbolt, m) {
.
def_readwrite
(
"indptr"
,
&
SampledSubgraph
::
indptr
)
.
def_readwrite
(
"indptr"
,
&
SampledSubgraph
::
indptr
)
.
def_readwrite
(
"indices"
,
&
SampledSubgraph
::
indices
)
.
def_readwrite
(
"indices"
,
&
SampledSubgraph
::
indices
)
.
def_readwrite
(
.
def_readwrite
(
"
reverse
_row_node_ids"
,
&
SampledSubgraph
::
reverse
_row_node_ids
)
"
original
_row_node_ids"
,
&
SampledSubgraph
::
original
_row_node_ids
)
.
def_readwrite
(
.
def_readwrite
(
"reverse_column_node_ids"
,
&
SampledSubgraph
::
reverse_column_node_ids
)
"original_column_node_ids"
,
.
def_readwrite
(
"reverse_edge_ids"
,
&
SampledSubgraph
::
reverse_edge_ids
)
&
SampledSubgraph
::
original_column_node_ids
)
.
def_readwrite
(
"original_edge_ids"
,
&
SampledSubgraph
::
original_edge_ids
)
.
def_readwrite
(
"type_per_edge"
,
&
SampledSubgraph
::
type_per_edge
);
.
def_readwrite
(
"type_per_edge"
,
&
SampledSubgraph
::
type_per_edge
);
m
.
class_
<
CSCSamplingGraph
>
(
"CSCSamplingGraph"
)
m
.
class_
<
CSCSamplingGraph
>
(
"CSCSamplingGraph"
)
.
def
(
"num_nodes"
,
&
CSCSamplingGraph
::
NumNodes
)
.
def
(
"num_nodes"
,
&
CSCSamplingGraph
::
NumNodes
)
...
...
python/dgl/graphbolt/feature_fetcher.py
View file @
ed3840fc
...
@@ -95,14 +95,14 @@ class FeatureFetcher(MiniBatchTransformer):
...
@@ -95,14 +95,14 @@ class FeatureFetcher(MiniBatchTransformer):
# Read Edge features.
# Read Edge features.
if
self
.
edge_feature_keys
and
data
.
sampled_subgraphs
:
if
self
.
edge_feature_keys
and
data
.
sampled_subgraphs
:
for
i
,
subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
for
i
,
subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
if
subgraph
.
reverse
_edge_ids
is
None
:
if
subgraph
.
original
_edge_ids
is
None
:
continue
continue
if
is_heterogeneous
:
if
is_heterogeneous
:
for
(
for
(
type_name
,
type_name
,
feature_names
,
feature_names
,
)
in
self
.
edge_feature_keys
.
items
():
)
in
self
.
edge_feature_keys
.
items
():
edges
=
subgraph
.
reverse
_edge_ids
.
get
(
type_name
,
None
)
edges
=
subgraph
.
original
_edge_ids
.
get
(
type_name
,
None
)
if
edges
is
None
:
if
edges
is
None
:
continue
continue
for
feature_name
in
feature_names
:
for
feature_name
in
feature_names
:
...
@@ -119,6 +119,6 @@ class FeatureFetcher(MiniBatchTransformer):
...
@@ -119,6 +119,6 @@ class FeatureFetcher(MiniBatchTransformer):
"edge"
,
"edge"
,
None
,
None
,
feature_name
,
feature_name
,
subgraph
.
reverse
_edge_ids
,
subgraph
.
original
_edge_ids
,
)
)
return
data
return
data
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
ed3840fc
...
@@ -225,7 +225,7 @@ class CSCSamplingGraph:
...
@@ -225,7 +225,7 @@ class CSCSamplingGraph:
column_num
=
(
column_num
=
(
C_sampled_subgraph
.
indptr
[
1
:]
-
C_sampled_subgraph
.
indptr
[:
-
1
]
C_sampled_subgraph
.
indptr
[
1
:]
-
C_sampled_subgraph
.
indptr
[:
-
1
]
)
)
column
=
C_sampled_subgraph
.
reverse
_column_node_ids
.
repeat_interleave
(
column
=
C_sampled_subgraph
.
original
_column_node_ids
.
repeat_interleave
(
column_num
column_num
)
)
row
=
C_sampled_subgraph
.
indices
row
=
C_sampled_subgraph
.
indices
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
ed3840fc
...
@@ -106,14 +106,14 @@ class NeighborSampler(SubgraphSampler):
...
@@ -106,14 +106,14 @@ class NeighborSampler(SubgraphSampler):
self
.
replace
,
self
.
replace
,
self
.
prob_name
,
self
.
prob_name
,
)
)
reverse
_column_node_ids
=
seeds
original
_column_node_ids
=
seeds
seeds
,
compacted_node_pairs
=
unique_and_compact_node_pairs
(
seeds
,
compacted_node_pairs
=
unique_and_compact_node_pairs
(
subgraph
.
node_pairs
,
seeds
subgraph
.
node_pairs
,
seeds
)
)
subgraph
=
SampledSubgraphImpl
(
subgraph
=
SampledSubgraphImpl
(
node_pairs
=
compacted_node_pairs
,
node_pairs
=
compacted_node_pairs
,
reverse
_column_node_ids
=
reverse
_column_node_ids
,
original
_column_node_ids
=
original
_column_node_ids
,
reverse
_row_node_ids
=
seeds
,
original
_row_node_ids
=
seeds
,
)
)
subgraphs
.
insert
(
0
,
subgraph
)
subgraphs
.
insert
(
0
,
subgraph
)
return
seeds
,
subgraphs
return
seeds
,
subgraphs
...
...
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
View file @
ed3840fc
...
@@ -17,31 +17,33 @@ class SampledSubgraphImpl(SampledSubgraph):
...
@@ -17,31 +17,33 @@ class SampledSubgraphImpl(SampledSubgraph):
--------
--------
>>> node_pairs = {"A:relation:B"): (torch.tensor([0, 1, 2]),
>>> node_pairs = {"A:relation:B"): (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
... torch.tensor([0, 1, 2]))}
>>>
reverse
_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>>
original
_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>>
reverse
_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>>
original
_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>>
reverse
_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>>
original
_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... node_pairs=node_pairs,
...
reverse
_column_node_ids=
reverse
_column_node_ids,
...
original
_column_node_ids=
original
_column_node_ids,
...
reverse
_row_node_ids=
reverse
_row_node_ids,
...
original
_row_node_ids=
original
_row_node_ids,
...
reverse
_edge_ids=
reverse
_edge_ids
...
original
_edge_ids=
original
_edge_ids
... )
... )
>>> print(subgraph.node_pairs)
>>> print(subgraph.node_pairs)
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.
reverse
_column_node_ids)
>>> print(subgraph.
original
_column_node_ids)
{'B': tensor([10, 11, 12])}
{'B': tensor([10, 11, 12])}
>>> print(subgraph.
reverse
_row_node_ids)
>>> print(subgraph.
original
_row_node_ids)
{'A': tensor([13, 14, 15])}
{'A': tensor([13, 14, 15])}
>>> print(subgraph.
reverse
_edge_ids)
>>> print(subgraph.
original
_edge_ids)
{"A:relation:B": tensor([19, 20, 21])}
{"A:relation:B": tensor([19, 20, 21])}
"""
"""
node_pairs
:
Union
[
node_pairs
:
Union
[
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
]
=
None
]
=
None
reverse_column_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
original_column_node_ids
:
Union
[
reverse_row_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
reverse_edge_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
]
=
None
original_row_node_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
original_edge_ids
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
isinstance
(
self
.
node_pairs
,
dict
):
if
isinstance
(
self
.
node_pairs
,
dict
):
...
...
python/dgl/graphbolt/minibatch.py
View file @
ed3840fc
...
@@ -140,14 +140,14 @@ class MiniBatch:
...
@@ -140,14 +140,14 @@ class MiniBatch:
blocks
=
[]
blocks
=
[]
for
subgraph
in
self
.
sampled_subgraphs
:
for
subgraph
in
self
.
sampled_subgraphs
:
reverse
_row_node_ids
=
subgraph
.
reverse
_row_node_ids
original
_row_node_ids
=
subgraph
.
original
_row_node_ids
assert
(
assert
(
reverse
_row_node_ids
is
not
None
original
_row_node_ids
is
not
None
),
"Missing `
reverse
_row_node_ids` in sampled subgraph."
),
"Missing `
original
_row_node_ids` in sampled subgraph."
reverse
_column_node_ids
=
subgraph
.
reverse
_column_node_ids
original
_column_node_ids
=
subgraph
.
original
_column_node_ids
assert
(
assert
(
reverse
_column_node_ids
is
not
None
original
_column_node_ids
is
not
None
),
"Missing `
reverse
_column_node_ids` in sampled subgraph."
),
"Missing `
original
_column_node_ids` in sampled subgraph."
if
is_heterogeneous
:
if
is_heterogeneous
:
node_pairs
=
{
node_pairs
=
{
etype_str_to_tuple
(
etype
):
v
etype_str_to_tuple
(
etype
):
v
...
@@ -155,16 +155,16 @@ class MiniBatch:
...
@@ -155,16 +155,16 @@ class MiniBatch:
}
}
num_src_nodes
=
{
num_src_nodes
=
{
ntype
:
nodes
.
size
(
0
)
ntype
:
nodes
.
size
(
0
)
for
ntype
,
nodes
in
reverse
_row_node_ids
.
items
()
for
ntype
,
nodes
in
original
_row_node_ids
.
items
()
}
}
num_dst_nodes
=
{
num_dst_nodes
=
{
ntype
:
nodes
.
size
(
0
)
ntype
:
nodes
.
size
(
0
)
for
ntype
,
nodes
in
reverse
_column_node_ids
.
items
()
for
ntype
,
nodes
in
original
_column_node_ids
.
items
()
}
}
else
:
else
:
node_pairs
=
subgraph
.
node_pairs
node_pairs
=
subgraph
.
node_pairs
num_src_nodes
=
reverse
_row_node_ids
.
size
(
0
)
num_src_nodes
=
original
_row_node_ids
.
size
(
0
)
num_dst_nodes
=
reverse
_column_node_ids
.
size
(
0
)
num_dst_nodes
=
original
_column_node_ids
.
size
(
0
)
blocks
.
append
(
blocks
.
append
(
dgl
.
create_block
(
dgl
.
create_block
(
node_pairs
,
node_pairs
,
...
@@ -194,15 +194,15 @@ class MiniBatch:
...
@@ -194,15 +194,15 @@ class MiniBatch:
# Assign reverse node ids to the outermost layer's source nodes.
# Assign reverse node ids to the outermost layer's source nodes.
for
node_type
,
reverse_ids
in
self
.
sampled_subgraphs
[
for
node_type
,
reverse_ids
in
self
.
sampled_subgraphs
[
0
0
].
reverse
_row_node_ids
.
items
():
].
original
_row_node_ids
.
items
():
blocks
[
0
].
srcnodes
[
node_type
].
data
[
dgl
.
NID
]
=
reverse_ids
blocks
[
0
].
srcnodes
[
node_type
].
data
[
dgl
.
NID
]
=
reverse_ids
# Assign reverse edges ids.
# Assign reverse edges ids.
for
block
,
subgraph
in
zip
(
blocks
,
self
.
sampled_subgraphs
):
for
block
,
subgraph
in
zip
(
blocks
,
self
.
sampled_subgraphs
):
if
subgraph
.
reverse
_edge_ids
:
if
subgraph
.
original
_edge_ids
:
for
(
for
(
edge_type
,
edge_type
,
reverse_ids
,
reverse_ids
,
)
in
subgraph
.
reverse
_edge_ids
.
items
():
)
in
subgraph
.
original
_edge_ids
.
items
():
block
.
edges
[
etype_str_to_tuple
(
edge_type
)].
data
[
block
.
edges
[
etype_str_to_tuple
(
edge_type
)].
data
[
dgl
.
EID
dgl
.
EID
]
=
reverse_ids
]
=
reverse_ids
...
@@ -218,11 +218,11 @@ class MiniBatch:
...
@@ -218,11 +218,11 @@ class MiniBatch:
block
.
edata
[
feature_name
]
=
feature
block
.
edata
[
feature_name
]
=
feature
blocks
[
0
].
srcdata
[
dgl
.
NID
]
=
self
.
sampled_subgraphs
[
blocks
[
0
].
srcdata
[
dgl
.
NID
]
=
self
.
sampled_subgraphs
[
0
0
].
reverse
_row_node_ids
].
original
_row_node_ids
# Assign reverse edges ids.
# Assign reverse edges ids.
for
block
,
subgraph
in
zip
(
blocks
,
self
.
sampled_subgraphs
):
for
block
,
subgraph
in
zip
(
blocks
,
self
.
sampled_subgraphs
):
if
subgraph
.
reverse
_edge_ids
is
not
None
:
if
subgraph
.
original
_edge_ids
is
not
None
:
block
.
edata
[
dgl
.
EID
]
=
subgraph
.
reverse
_edge_ids
block
.
edata
[
dgl
.
EID
]
=
subgraph
.
original
_edge_ids
return
blocks
return
blocks
...
...
python/dgl/graphbolt/sampled_subgraph.py
View file @
ed3840fc
...
@@ -29,16 +29,16 @@ class SampledSubgraph:
...
@@ -29,16 +29,16 @@ class SampledSubgraph:
raise
NotImplementedError
raise
NotImplementedError
@
property
@
property
def
reverse
_column_node_ids
(
def
original
_column_node_ids
(
self
,
self
,
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns corresponding reverse column node ids the original graph.
"""Returns corresponding reverse column node ids the original graph.
Column's reverse node ids in the original graph. A graph structure
Column's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
can be treated as a coordinated row and column pair, and this is
the mapped ids of the column.
the mapped ids of the column.
- If `
reverse
_column_node_ids` is a tensor: It represents the
- If `
original
_column_node_ids` is a tensor: It represents the
original node ids.
original node ids.
- If `
reverse
_column_node_ids` is a dictionary: The keys should be
- If `
original
_column_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
node type and the values should be corresponding original
heterogeneous node ids.
heterogeneous node ids.
If present, it means column IDs are compacted, and `node_pairs`
If present, it means column IDs are compacted, and `node_pairs`
...
@@ -47,16 +47,16 @@ class SampledSubgraph:
...
@@ -47,16 +47,16 @@ class SampledSubgraph:
return
None
return
None
@
property
@
property
def
reverse
_row_node_ids
(
def
original
_row_node_ids
(
self
,
self
,
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns corresponding reverse row node ids the original graph.
"""Returns corresponding reverse row node ids the original graph.
Row's reverse node ids in the original graph. A graph structure
Row's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
can be treated as a coordinated row and column pair, and this is
the mapped ids of the row.
the mapped ids of the row.
- If `
reverse
_row_node_ids` is a tensor: It represents the
- If `
original
_row_node_ids` is a tensor: It represents the
original node ids.
original node ids.
- If `
reverse
_row_node_ids` is a dictionary: The keys should be
- If `
original
_row_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
node type and the values should be corresponding original
heterogeneous node ids.
heterogeneous node ids.
If present, it means row IDs are compacted, and `node_pairs`
If present, it means row IDs are compacted, and `node_pairs`
...
@@ -64,13 +64,13 @@ class SampledSubgraph:
...
@@ -64,13 +64,13 @@ class SampledSubgraph:
return
None
return
None
@
property
@
property
def
reverse
_edge_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
def
original
_edge_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""Returns corresponding reverse edge ids the original graph.
"""Returns corresponding reverse edge ids the original graph.
Reverse edge ids in the original graph. This is useful when edge
Reverse edge ids in the original graph. This is useful when edge
features are needed.
features are needed.
- If `
reverse
_edge_ids` is a tensor: It represents the
- If `
original
_edge_ids` is a tensor: It represents the
original edge ids.
original edge ids.
- If `
reverse
_edge_ids` is a dictionary: The keys should be
- If `
original
_edge_ids` is a dictionary: The keys should be
edge type and the values should be corresponding original
edge type and the values should be corresponding original
heterogeneous edge ids.
heterogeneous edge ids.
"""
"""
...
@@ -110,24 +110,24 @@ class SampledSubgraph:
...
@@ -110,24 +110,24 @@ class SampledSubgraph:
--------
--------
>>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]),
>>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
... torch.tensor([0, 1, 2]))}
>>>
reverse
_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>>
original
_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>>
reverse
_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>>
original
_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>>
reverse
_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>>
original
_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... node_pairs=node_pairs,
...
reverse
_column_node_ids=
reverse
_column_node_ids,
...
original
_column_node_ids=
original
_column_node_ids,
...
reverse
_row_node_ids=
reverse
_row_node_ids,
...
original
_row_node_ids=
original
_row_node_ids,
...
reverse
_edge_ids=
reverse
_edge_ids
...
original
_edge_ids=
original
_edge_ids
... )
... )
>>> edges_to_exclude = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> edges_to_exclude = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> result = subgraph.exclude_edges(edges_to_exclude)
>>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.node_pairs)
>>> print(result.node_pairs)
{"A:relation:B": (tensor([0]), tensor([0]))}
{"A:relation:B": (tensor([0]), tensor([0]))}
>>> print(result.
reverse
_column_node_ids)
>>> print(result.
original
_column_node_ids)
{'B': tensor([10, 11, 12])}
{'B': tensor([10, 11, 12])}
>>> print(result.
reverse
_row_node_ids)
>>> print(result.
original
_row_node_ids)
{'A': tensor([13, 14, 15])}
{'A': tensor([13, 14, 15])}
>>> print(result.
reverse
_edge_ids)
>>> print(result.
original
_edge_ids)
{"A:relation:B": tensor([19])}
{"A:relation:B": tensor([19])}
"""
"""
assert
isinstance
(
self
.
node_pairs
,
tuple
)
==
isinstance
(
edges
,
tuple
),
(
assert
isinstance
(
self
.
node_pairs
,
tuple
)
==
isinstance
(
edges
,
tuple
),
(
...
@@ -144,8 +144,8 @@ class SampledSubgraph:
...
@@ -144,8 +144,8 @@ class SampledSubgraph:
if
isinstance
(
self
.
node_pairs
,
tuple
):
if
isinstance
(
self
.
node_pairs
,
tuple
):
reverse_edges
=
_to_reverse_ids
(
reverse_edges
=
_to_reverse_ids
(
self
.
node_pairs
,
self
.
node_pairs
,
self
.
reverse
_row_node_ids
,
self
.
original
_row_node_ids
,
self
.
reverse
_column_node_ids
,
self
.
original
_column_node_ids
,
)
)
index
=
_exclude_homo_edges
(
reverse_edges
,
edges
)
index
=
_exclude_homo_edges
(
reverse_edges
,
edges
)
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
...
@@ -153,20 +153,20 @@ class SampledSubgraph:
...
@@ -153,20 +153,20 @@ class SampledSubgraph:
index
=
{}
index
=
{}
for
etype
,
pair
in
self
.
node_pairs
.
items
():
for
etype
,
pair
in
self
.
node_pairs
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
reverse
_row_node_ids
=
(
original
_row_node_ids
=
(
None
None
if
self
.
reverse
_row_node_ids
is
None
if
self
.
original
_row_node_ids
is
None
else
self
.
reverse
_row_node_ids
.
get
(
src_type
)
else
self
.
original
_row_node_ids
.
get
(
src_type
)
)
)
reverse
_column_node_ids
=
(
original
_column_node_ids
=
(
None
None
if
self
.
reverse
_column_node_ids
is
None
if
self
.
original
_column_node_ids
is
None
else
self
.
reverse
_column_node_ids
.
get
(
dst_type
)
else
self
.
original
_column_node_ids
.
get
(
dst_type
)
)
)
reverse_edges
=
_to_reverse_ids
(
reverse_edges
=
_to_reverse_ids
(
pair
,
pair
,
reverse
_row_node_ids
,
original
_row_node_ids
,
reverse
_column_node_ids
,
original
_column_node_ids
,
)
)
index
[
etype
]
=
_exclude_homo_edges
(
index
[
etype
]
=
_exclude_homo_edges
(
reverse_edges
,
edges
.
get
(
etype
)
reverse_edges
,
edges
.
get
(
etype
)
...
@@ -174,12 +174,12 @@ class SampledSubgraph:
...
@@ -174,12 +174,12 @@ class SampledSubgraph:
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
def
_to_reverse_ids
(
node_pair
,
reverse
_row_node_ids
,
reverse
_column_node_ids
):
def
_to_reverse_ids
(
node_pair
,
original
_row_node_ids
,
original
_column_node_ids
):
u
,
v
=
node_pair
u
,
v
=
node_pair
if
reverse
_row_node_ids
is
not
None
:
if
original
_row_node_ids
is
not
None
:
u
=
reverse
_row_node_ids
[
u
]
u
=
original
_row_node_ids
[
u
]
if
reverse
_column_node_ids
is
not
None
:
if
original
_column_node_ids
is
not
None
:
v
=
reverse
_column_node_ids
[
v
]
v
=
original
_column_node_ids
[
v
]
return
(
u
,
v
)
return
(
u
,
v
)
...
@@ -224,7 +224,7 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
...
@@ -224,7 +224,7 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
return
(
return
(
_index_select
(
subgraph
.
node_pairs
,
index
),
_index_select
(
subgraph
.
node_pairs
,
index
),
subgraph
.
reverse
_column_node_ids
,
subgraph
.
original
_column_node_ids
,
subgraph
.
reverse
_row_node_ids
,
subgraph
.
original
_row_node_ids
,
_index_select
(
subgraph
.
reverse
_edge_ids
,
index
),
_index_select
(
subgraph
.
original
_edge_ids
,
index
),
)
)
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
ed3840fc
...
@@ -397,12 +397,12 @@ def test_in_subgraph_homogeneous():
...
@@ -397,12 +397,12 @@ def test_in_subgraph_homogeneous():
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
indices
,
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
in_subgraph
.
indices
,
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
)
)
assert
torch
.
equal
(
in_subgraph
.
reverse
_column_node_ids
,
nodes
)
assert
torch
.
equal
(
in_subgraph
.
original
_column_node_ids
,
nodes
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
reverse
_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
in_subgraph
.
original
_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
reverse
_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
in_subgraph
.
original
_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
)
)
assert
in_subgraph
.
type_per_edge
is
None
assert
in_subgraph
.
type_per_edge
is
None
...
@@ -463,12 +463,12 @@ def test_in_subgraph_heterogeneous():
...
@@ -463,12 +463,12 @@ def test_in_subgraph_heterogeneous():
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
indices
,
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
in_subgraph
.
indices
,
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
)
)
assert
torch
.
equal
(
in_subgraph
.
reverse
_column_node_ids
,
nodes
)
assert
torch
.
equal
(
in_subgraph
.
original
_column_node_ids
,
nodes
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
reverse
_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
in_subgraph
.
original
_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
reverse
_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
in_subgraph
.
original
_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
type_per_edge
,
torch
.
LongTensor
([
2
,
2
,
1
,
3
,
1
,
3
,
3
])
in_subgraph
.
type_per_edge
,
torch
.
LongTensor
([
2
,
2
,
1
,
3
,
1
,
3
,
3
])
...
@@ -505,9 +505,9 @@ def test_sample_neighbors_homo():
...
@@ -505,9 +505,9 @@ def test_sample_neighbors_homo():
# Verify in subgraph.
# Verify in subgraph.
sampled_num
=
subgraph
.
node_pairs
[
0
].
size
(
0
)
sampled_num
=
subgraph
.
node_pairs
[
0
].
size
(
0
)
assert
sampled_num
==
6
assert
sampled_num
==
6
assert
subgraph
.
reverse
_column_node_ids
is
None
assert
subgraph
.
original
_column_node_ids
is
None
assert
subgraph
.
reverse
_row_node_ids
is
None
assert
subgraph
.
original
_row_node_ids
is
None
assert
subgraph
.
reverse
_edge_ids
is
None
assert
subgraph
.
original
_edge_ids
is
None
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
@@ -568,9 +568,9 @@ def test_sample_neighbors_hetero(labor):
...
@@ -568,9 +568,9 @@ def test_sample_neighbors_hetero(labor):
for
etype
,
pairs
in
expected_node_pairs
.
items
():
for
etype
,
pairs
in
expected_node_pairs
.
items
():
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
0
],
pairs
[
0
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
0
],
pairs
[
0
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
1
],
pairs
[
1
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
1
],
pairs
[
1
])
assert
subgraph
.
reverse
_column_node_ids
is
None
assert
subgraph
.
original
_column_node_ids
is
None
assert
subgraph
.
reverse
_row_node_ids
is
None
assert
subgraph
.
original
_row_node_ids
is
None
assert
subgraph
.
reverse
_edge_ids
is
None
assert
subgraph
.
original
_edge_ids
is
None
# Sample on single node type.
# Sample on single node type.
nodes
=
{
"n1"
:
torch
.
LongTensor
([
0
])}
nodes
=
{
"n1"
:
torch
.
LongTensor
([
0
])}
...
@@ -593,9 +593,9 @@ def test_sample_neighbors_hetero(labor):
...
@@ -593,9 +593,9 @@ def test_sample_neighbors_hetero(labor):
for
etype
,
pairs
in
expected_node_pairs
.
items
():
for
etype
,
pairs
in
expected_node_pairs
.
items
():
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
0
],
pairs
[
0
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
0
],
pairs
[
0
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
1
],
pairs
[
1
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
1
],
pairs
[
1
])
assert
subgraph
.
reverse
_column_node_ids
is
None
assert
subgraph
.
original
_column_node_ids
is
None
assert
subgraph
.
reverse
_row_node_ids
is
None
assert
subgraph
.
original
_row_node_ids
is
None
assert
subgraph
.
reverse
_edge_ids
is
None
assert
subgraph
.
original
_edge_ids
is
None
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
...
tests/python/pytorch/graphbolt/impl/test_minibatch.py
View file @
ed3840fc
...
@@ -13,11 +13,11 @@ def test_to_dgl_blocks_hetero():
...
@@ -13,11 +13,11 @@ def test_to_dgl_blocks_hetero():
},
},
{
relation
:
(
torch
.
tensor
([
0
,
1
]),
torch
.
tensor
([
1
,
0
]))},
{
relation
:
(
torch
.
tensor
([
0
,
1
]),
torch
.
tensor
([
1
,
0
]))},
]
]
reverse
_column_node_ids
=
[
original
_column_node_ids
=
[
{
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
])},
{
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
])},
{
"B"
:
torch
.
tensor
([
10
,
11
])},
{
"B"
:
torch
.
tensor
([
10
,
11
])},
]
]
reverse
_row_node_ids
=
[
original
_row_node_ids
=
[
{
{
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
]),
"A"
:
torch
.
tensor
([
5
,
7
,
9
,
11
]),
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
...
@@ -27,7 +27,7 @@ def test_to_dgl_blocks_hetero():
...
@@ -27,7 +27,7 @@ def test_to_dgl_blocks_hetero():
"B"
:
torch
.
tensor
([
10
,
11
]),
"B"
:
torch
.
tensor
([
10
,
11
]),
},
},
]
]
reverse
_edge_ids
=
[
original
_edge_ids
=
[
{
{
relation
:
torch
.
tensor
([
19
,
20
,
21
]),
relation
:
torch
.
tensor
([
19
,
20
,
21
]),
reverse_relation
:
torch
.
tensor
([
23
,
26
]),
reverse_relation
:
torch
.
tensor
([
23
,
26
]),
...
@@ -46,9 +46,9 @@ def test_to_dgl_blocks_hetero():
...
@@ -46,9 +46,9 @@ def test_to_dgl_blocks_hetero():
subgraphs
.
append
(
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
node_pairs
=
node_pairs
[
i
],
reverse
_column_node_ids
=
reverse
_column_node_ids
[
i
],
original
_column_node_ids
=
original
_column_node_ids
[
i
],
reverse
_row_node_ids
=
reverse
_row_node_ids
[
i
],
original
_row_node_ids
=
original
_row_node_ids
[
i
],
reverse
_edge_ids
=
reverse
_edge_ids
[
i
],
original
_edge_ids
=
original
_edge_ids
[
i
],
)
)
)
)
blocks
=
gb
.
MiniBatch
(
blocks
=
gb
.
MiniBatch
(
...
@@ -63,7 +63,7 @@ def test_to_dgl_blocks_hetero():
...
@@ -63,7 +63,7 @@ def test_to_dgl_blocks_hetero():
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
i
][
relation
][
0
])
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
i
][
relation
][
0
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
i
][
relation
][
1
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
i
][
relation
][
1
])
assert
torch
.
equal
(
assert
torch
.
equal
(
block
.
edges
[
etype
].
data
[
dgl
.
EID
],
reverse
_edge_ids
[
i
][
relation
]
block
.
edges
[
etype
].
data
[
dgl
.
EID
],
original
_edge_ids
[
i
][
relation
]
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
block
.
edges
[
etype
].
data
[
"x"
],
block
.
edges
[
etype
].
data
[
"x"
],
...
@@ -73,10 +73,10 @@ def test_to_dgl_blocks_hetero():
...
@@ -73,10 +73,10 @@ def test_to_dgl_blocks_hetero():
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
0
][
reverse_relation
][
0
])
assert
torch
.
equal
(
edges
[
0
],
node_pairs
[
0
][
reverse_relation
][
0
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
0
][
reverse_relation
][
1
])
assert
torch
.
equal
(
edges
[
1
],
node_pairs
[
0
][
reverse_relation
][
1
])
assert
torch
.
equal
(
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"A"
],
reverse
_row_node_ids
[
0
][
"A"
]
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"A"
],
original
_row_node_ids
[
0
][
"A"
]
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"B"
],
reverse
_row_node_ids
[
0
][
"B"
]
blocks
[
0
].
srcdata
[
dgl
.
NID
][
"B"
],
original
_row_node_ids
[
0
][
"B"
]
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
blocks
[
0
].
srcnodes
[
"A"
].
data
[
"x"
],
node_features
[(
"A"
,
"x"
)]
blocks
[
0
].
srcnodes
[
"A"
].
data
[
"x"
],
node_features
[(
"A"
,
"x"
)]
...
@@ -97,15 +97,15 @@ def test_to_dgl_blocks_homo():
...
@@ -97,15 +97,15 @@ def test_to_dgl_blocks_homo():
torch
.
tensor
([
1
,
0
,
0
]),
torch
.
tensor
([
1
,
0
,
0
]),
),
),
]
]
reverse
_column_node_ids
=
[
original
_column_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
]),
torch
.
tensor
([
10
,
11
]),
]
]
reverse
_row_node_ids
=
[
original
_row_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
,
12
]),
torch
.
tensor
([
10
,
11
,
12
]),
]
]
reverse
_edge_ids
=
[
original
_edge_ids
=
[
torch
.
tensor
([
19
,
20
,
21
,
22
,
25
,
30
]),
torch
.
tensor
([
19
,
20
,
21
,
22
,
25
,
30
]),
torch
.
tensor
([
10
,
15
,
17
]),
torch
.
tensor
([
10
,
15
,
17
]),
]
]
...
@@ -119,9 +119,9 @@ def test_to_dgl_blocks_homo():
...
@@ -119,9 +119,9 @@ def test_to_dgl_blocks_homo():
subgraphs
.
append
(
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
node_pairs
=
node_pairs
[
i
],
reverse
_column_node_ids
=
reverse
_column_node_ids
[
i
],
original
_column_node_ids
=
original
_column_node_ids
[
i
],
reverse
_row_node_ids
=
reverse
_row_node_ids
[
i
],
original
_row_node_ids
=
original
_row_node_ids
[
i
],
reverse
_edge_ids
=
reverse
_edge_ids
[
i
],
original
_edge_ids
=
original
_edge_ids
[
i
],
)
)
)
)
blocks
=
gb
.
MiniBatch
(
blocks
=
gb
.
MiniBatch
(
...
@@ -133,9 +133,9 @@ def test_to_dgl_blocks_homo():
...
@@ -133,9 +133,9 @@ def test_to_dgl_blocks_homo():
for
i
,
block
in
enumerate
(
blocks
):
for
i
,
block
in
enumerate
(
blocks
):
assert
torch
.
equal
(
block
.
edges
()[
0
],
node_pairs
[
i
][
0
])
assert
torch
.
equal
(
block
.
edges
()[
0
],
node_pairs
[
i
][
0
])
assert
torch
.
equal
(
block
.
edges
()[
1
],
node_pairs
[
i
][
1
])
assert
torch
.
equal
(
block
.
edges
()[
1
],
node_pairs
[
i
][
1
])
assert
torch
.
equal
(
block
.
edata
[
dgl
.
EID
],
reverse
_edge_ids
[
i
])
assert
torch
.
equal
(
block
.
edata
[
dgl
.
EID
],
original
_edge_ids
[
i
])
assert
torch
.
equal
(
block
.
edata
[
"x"
],
edge_features
[
i
][
"x"
])
assert
torch
.
equal
(
block
.
edata
[
"x"
],
edge_features
[
i
][
"x"
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
reverse
_row_node_ids
[
0
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
original
_row_node_ids
[
0
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
"x"
],
node_features
[
"x"
])
assert
torch
.
equal
(
blocks
[
0
].
srcdata
[
"x"
],
node_features
[
"x"
])
...
@@ -150,15 +150,15 @@ def test_representation():
...
@@ -150,15 +150,15 @@ def test_representation():
torch
.
tensor
([
1
,
0
,
0
]),
torch
.
tensor
([
1
,
0
,
0
]),
),
),
]
]
reverse
_column_node_ids
=
[
original
_column_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
]),
torch
.
tensor
([
10
,
11
]),
]
]
reverse
_row_node_ids
=
[
original
_row_node_ids
=
[
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
,
12
,
13
]),
torch
.
tensor
([
10
,
11
,
12
]),
torch
.
tensor
([
10
,
11
,
12
]),
]
]
reverse
_edge_ids
=
[
original
_edge_ids
=
[
torch
.
tensor
([
19
,
20
,
21
,
22
,
25
,
30
]),
torch
.
tensor
([
19
,
20
,
21
,
22
,
25
,
30
]),
torch
.
tensor
([
10
,
15
,
17
]),
torch
.
tensor
([
10
,
15
,
17
]),
]
]
...
@@ -172,9 +172,9 @@ def test_representation():
...
@@ -172,9 +172,9 @@ def test_representation():
subgraphs
.
append
(
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
node_pairs
=
node_pairs
[
i
],
reverse
_column_node_ids
=
reverse
_column_node_ids
[
i
],
original
_column_node_ids
=
original
_column_node_ids
[
i
],
reverse
_row_node_ids
=
reverse
_row_node_ids
[
i
],
original
_row_node_ids
=
original
_row_node_ids
[
i
],
reverse
_edge_ids
=
reverse
_edge_ids
[
i
],
original
_edge_ids
=
original
_edge_ids
[
i
],
)
)
)
)
negative_srcs
=
torch
.
tensor
([[
8
],
[
1
],
[
6
]])
negative_srcs
=
torch
.
tensor
([[
8
],
[
1
],
[
6
]])
...
@@ -220,13 +220,13 @@ def test_representation():
...
@@ -220,13 +220,13 @@ def test_representation():
expect_result
=
str
(
expect_result
=
str
(
"""MiniBatch(seed_nodes=None,
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
reverse
_column_node_ids=tensor([10, 11, 12, 13]),
original
_column_node_ids=tensor([10, 11, 12, 13]),
reverse
_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original
_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
reverse
_row_node_ids=tensor([10, 11, 12, 13]),),
original
_row_node_ids=tensor([10, 11, 12, 13]),),
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
reverse
_column_node_ids=tensor([10, 11]),
original
_column_node_ids=tensor([10, 11]),
reverse
_edge_ids=tensor([10, 15, 17]),
original
_edge_ids=tensor([10, 15, 17]),
reverse
_row_node_ids=tensor([10, 11, 12]),)],
original
_row_node_ids=tensor([10, 11, 12]),)],
node_pairs=[(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
node_pairs=[(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
(tensor([0, 1, 2]), tensor([1, 0, 0]))],
(tensor([0, 1, 2]), tensor([1, 0, 0]))],
node_features={'x': tensor([7, 6, 2, 2])},
node_features={'x': tensor([7, 6, 2, 2])},
...
...
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
View file @
ed3840fc
...
@@ -26,24 +26,24 @@ def _assert_container_equal(lhs, rhs):
...
@@ -26,24 +26,24 @@ def _assert_container_equal(lhs, rhs):
def
test_exclude_edges_homo
(
reverse_row
,
reverse_column
):
def
test_exclude_edges_homo
(
reverse_row
,
reverse_column
):
node_pairs
=
(
torch
.
tensor
([
0
,
2
,
3
]),
torch
.
tensor
([
1
,
4
,
2
]))
node_pairs
=
(
torch
.
tensor
([
0
,
2
,
3
]),
torch
.
tensor
([
1
,
4
,
2
]))
if
reverse_row
:
if
reverse_row
:
reverse
_row_node_ids
=
torch
.
tensor
([
10
,
15
,
11
,
24
,
9
])
original
_row_node_ids
=
torch
.
tensor
([
10
,
15
,
11
,
24
,
9
])
src_to_exclude
=
torch
.
tensor
([
11
])
src_to_exclude
=
torch
.
tensor
([
11
])
else
:
else
:
reverse
_row_node_ids
=
None
original
_row_node_ids
=
None
src_to_exclude
=
torch
.
tensor
([
2
])
src_to_exclude
=
torch
.
tensor
([
2
])
if
reverse_column
:
if
reverse_column
:
reverse
_column_node_ids
=
torch
.
tensor
([
10
,
15
,
11
,
24
,
9
])
original
_column_node_ids
=
torch
.
tensor
([
10
,
15
,
11
,
24
,
9
])
dst_to_exclude
=
torch
.
tensor
([
9
])
dst_to_exclude
=
torch
.
tensor
([
9
])
else
:
else
:
reverse
_column_node_ids
=
None
original
_column_node_ids
=
None
dst_to_exclude
=
torch
.
tensor
([
4
])
dst_to_exclude
=
torch
.
tensor
([
4
])
reverse
_edge_ids
=
torch
.
Tensor
([
5
,
9
,
10
])
original
_edge_ids
=
torch
.
Tensor
([
5
,
9
,
10
])
subgraph
=
SampledSubgraphImpl
(
subgraph
=
SampledSubgraphImpl
(
node_pairs
,
node_pairs
,
reverse
_column_node_ids
,
original
_column_node_ids
,
reverse
_row_node_ids
,
original
_row_node_ids
,
reverse
_edge_ids
,
original
_edge_ids
,
)
)
edges_to_exclude
=
(
src_to_exclude
,
dst_to_exclude
)
edges_to_exclude
=
(
src_to_exclude
,
dst_to_exclude
)
result
=
subgraph
.
exclude_edges
(
edges_to_exclude
)
result
=
subgraph
.
exclude_edges
(
edges_to_exclude
)
...
@@ -60,10 +60,10 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
...
@@ -60,10 +60,10 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
_assert_container_equal
(
result
.
node_pairs
,
expected_node_pairs
)
_assert_container_equal
(
result
.
node_pairs
,
expected_node_pairs
)
_assert_container_equal
(
_assert_container_equal
(
result
.
reverse
_column_node_ids
,
expected_column_node_ids
result
.
original
_column_node_ids
,
expected_column_node_ids
)
)
_assert_container_equal
(
result
.
reverse
_row_node_ids
,
expected_row_node_ids
)
_assert_container_equal
(
result
.
original
_row_node_ids
,
expected_row_node_ids
)
_assert_container_equal
(
result
.
reverse
_edge_ids
,
expected_edge_ids
)
_assert_container_equal
(
result
.
original
_edge_ids
,
expected_edge_ids
)
@
pytest
.
mark
.
parametrize
(
"reverse_row"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"reverse_row"
,
[
True
,
False
])
...
@@ -76,27 +76,27 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
...
@@ -76,27 +76,27 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
)
)
}
}
if
reverse_row
:
if
reverse_row
:
reverse
_row_node_ids
=
{
original
_row_node_ids
=
{
"A"
:
torch
.
tensor
([
13
,
14
,
15
]),
"A"
:
torch
.
tensor
([
13
,
14
,
15
]),
}
}
src_to_exclude
=
torch
.
tensor
([
15
,
13
])
src_to_exclude
=
torch
.
tensor
([
15
,
13
])
else
:
else
:
reverse
_row_node_ids
=
None
original
_row_node_ids
=
None
src_to_exclude
=
torch
.
tensor
([
2
,
0
])
src_to_exclude
=
torch
.
tensor
([
2
,
0
])
if
reverse_column
:
if
reverse_column
:
reverse
_column_node_ids
=
{
original
_column_node_ids
=
{
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
}
}
dst_to_exclude
=
torch
.
tensor
([
10
,
12
])
dst_to_exclude
=
torch
.
tensor
([
10
,
12
])
else
:
else
:
reverse
_column_node_ids
=
None
original
_column_node_ids
=
None
dst_to_exclude
=
torch
.
tensor
([
0
,
2
])
dst_to_exclude
=
torch
.
tensor
([
0
,
2
])
reverse
_edge_ids
=
{
"A:relation:B"
:
torch
.
tensor
([
19
,
20
,
21
])}
original
_edge_ids
=
{
"A:relation:B"
:
torch
.
tensor
([
19
,
20
,
21
])}
subgraph
=
SampledSubgraphImpl
(
subgraph
=
SampledSubgraphImpl
(
node_pairs
=
node_pairs
,
node_pairs
=
node_pairs
,
reverse
_column_node_ids
=
reverse
_column_node_ids
,
original
_column_node_ids
=
original
_column_node_ids
,
reverse
_row_node_ids
=
reverse
_row_node_ids
,
original
_row_node_ids
=
original
_row_node_ids
,
reverse
_edge_ids
=
reverse
_edge_ids
,
original
_edge_ids
=
original
_edge_ids
,
)
)
edges_to_exclude
=
{
edges_to_exclude
=
{
...
@@ -128,7 +128,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
...
@@ -128,7 +128,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
_assert_container_equal
(
result
.
node_pairs
,
expected_node_pairs
)
_assert_container_equal
(
result
.
node_pairs
,
expected_node_pairs
)
_assert_container_equal
(
_assert_container_equal
(
result
.
reverse
_column_node_ids
,
expected_column_node_ids
result
.
original
_column_node_ids
,
expected_column_node_ids
)
)
_assert_container_equal
(
result
.
reverse
_row_node_ids
,
expected_row_node_ids
)
_assert_container_equal
(
result
.
original
_row_node_ids
,
expected_row_node_ids
)
_assert_container_equal
(
result
.
reverse
_edge_ids
,
expected_edge_ids
)
_assert_container_equal
(
result
.
original
_edge_ids
,
expected_edge_ids
)
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
ed3840fc
...
@@ -65,7 +65,7 @@ def test_FeatureFetcher_with_edges_homo():
...
@@ -65,7 +65,7 @@ def test_FeatureFetcher_with_edges_homo():
subgraphs
.
append
(
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
reverse
_edge_ids
=
torch
.
randint
(
0
,
graph
.
num_edges
,
(
10
,)),
original
_edge_ids
=
torch
.
randint
(
0
,
graph
.
num_edges
,
(
10
,)),
)
)
)
)
data
=
gb
.
MiniBatch
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
data
=
gb
.
MiniBatch
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
...
@@ -146,7 +146,7 @@ def test_FeatureFetcher_with_edges_hetero():
...
@@ -146,7 +146,7 @@ def test_FeatureFetcher_with_edges_hetero():
def
add_node_and_edge_ids
(
seeds
):
def
add_node_and_edge_ids
(
seeds
):
subgraphs
=
[]
subgraphs
=
[]
reverse
_edge_ids
=
{
original
_edge_ids
=
{
"n1:e1:n2"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n1:e1:n2"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n2:e2:n1"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
"n2:e2:n1"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
}
}
...
@@ -154,7 +154,7 @@ def test_FeatureFetcher_with_edges_hetero():
...
@@ -154,7 +154,7 @@ def test_FeatureFetcher_with_edges_hetero():
subgraphs
.
append
(
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
SampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
reverse
_edge_ids
=
reverse
_edge_ids
,
original
_edge_ids
=
original
_edge_ids
,
)
)
)
)
data
=
gb
.
MiniBatch
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
data
=
gb
.
MiniBatch
(
input_nodes
=
seeds
,
sampled_subgraphs
=
subgraphs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment