Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
1a649657
Unverified
Commit
1a649657
authored
May 25, 2023
by
peizhou001
Committed by
GitHub
May 25, 2023
Browse files
[GraphBolt] Add subgraph binding (#5741)
parent
7438b108
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
24 deletions
+38
-24
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+1
-2
graphbolt/include/graphbolt/sampled_subgraph.h
graphbolt/include/graphbolt/sampled_subgraph.h
+25
-22
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+10
-0
python/dgl/graphbolt/__init__.py
python/dgl/graphbolt/__init__.py
+2
-0
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
1a649657
...
...
@@ -6,8 +6,7 @@
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <graphbolt/sampled_subgraph.h>
#include <string>
#include <vector>
...
...
graphbolt/include/graphbolt/sampled_subgraph.h
View file @
1a649657
...
...
@@ -22,20 +22,20 @@ namespace sampling {
* ```
* auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64});
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto reverse_
row
_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
* auto reverse_
column
_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
*
* SampledSubgraph sampledSubgraph(indptr, indices, reverse_
row
_node_ids);
* SampledSubgraph sampledSubgraph(indptr, indices, reverse_
column
_node_ids);
* ```
*
* The `reverse_
row
_node_ids` indicates that nodes `[3, 3, 101]` in the
* The `reverse_
column
_node_ids` indicates that nodes `[3, 3, 101]` in the
* original graph are mapped to `[0, 1, 2]` in this subgraph, and because
* `reverse_
column
_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* `reverse_
row
_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* the original node ids without compaction.
*
* If `reverse_
column
_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* it would indicate a different mapping for the
column
nodes. Note this is
* inconsistent with
row
, which is legal, as `3` is mapped to `0` and `1` in
the
*
row
while `2` in the
column
.
* If `reverse_
row
_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* it would indicate a different mapping for the
row
nodes. Note this is
* inconsistent with
column
, which is legal, as `3` is mapped to `0` and `1` in
*
the column
while `2` in the
row
.
*/
struct
SampledSubgraph
:
torch
::
CustomClassHolder
{
public:
...
...
@@ -44,58 +44,61 @@ struct SampledSubgraph : torch::CustomClassHolder {
*
* @param indptr CSC format index pointer array.
* @param indices CSC format index array.
* @param reverse_row_node_ids Row's reverse node ids in the original graph.
* @param reverse_column_node_ids Column's reverse node ids in the original
* @param reverse_column_node_ids Row's reverse node ids in the original
* graph.
* @param reverse_row_node_ids Column's reverse node ids in the original
* graph.
* @param reverse_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge.
*/
SampledSubgraph
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
reverse_
row
_node_ids
,
torch
::
optional
<
torch
::
Tensor
>
reverse_
column
_node_ids
=
torch
::
nullopt
,
torch
::
Tensor
reverse_
column
_node_ids
,
torch
::
optional
<
torch
::
Tensor
>
reverse_
row
_node_ids
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
reverse_edge_ids
=
torch
::
nullopt
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
=
torch
::
nullopt
)
:
indptr
(
indptr
),
indices
(
indices
),
reverse_row_node_ids
(
reverse_row_node_ids
),
reverse_column_node_ids
(
reverse_column_node_ids
),
reverse_row_node_ids
(
reverse_row_node_ids
),
reverse_edge_ids
(
reverse_edge_ids
),
type_per_edge
(
type_per_edge
)
{}
SampledSubgraph
()
=
default
;
/**
* @brief CSC format index pointer array, where the implicit node ids are
* already compacted. And the original ids are stored in the
* `reverse_
row
_node_ids` field.
* `reverse_
column
_node_ids` field.
*/
torch
::
Tensor
indptr
;
/**
* @brief CSC format index array, where the node ids can be compacted ids or
* original ids. If compacted, the original ids are stored in the
* `reverse_
column
_node_ids` field.
* `reverse_
row
_node_ids` field.
*/
torch
::
Tensor
indices
;
/**
* @brief
Row
's reverse node ids in the original graph. A graph structure
can
* be treated as a coordinated row and column pair, and this is the the
mapped
* ids of the
row
.
* @brief
Column
's reverse node ids in the original graph. A graph structure
*
can
be treated as a coordinated row and column pair, and this is the the
*
mapped
ids of the
column
.
*
* @note This is required and the mapping relations can be inconsistent with
* column's.
*/
torch
::
Tensor
reverse_
row
_node_ids
;
torch
::
Tensor
reverse_
column
_node_ids
;
/**
* @brief
Column
's reverse node ids in the original graph. A graph structure
* @brief
Row
's reverse node ids in the original graph. A graph structure
* can be treated as a coordinated row and column pair, and this is the the
* mapped ids of the
column
.
* mapped ids of the
row
.
*
* @note This is optional and the mapping relations can be inconsistent with
* row's.
*/
torch
::
optional
<
torch
::
Tensor
>
reverse_
column
_node_ids
;
torch
::
optional
<
torch
::
Tensor
>
reverse_
row
_node_ids
;
/**
* @brief Reverse edge ids in the original graph, the edge with id
...
...
graphbolt/src/python_binding.cc
View file @
1a649657
...
...
@@ -11,6 +11,16 @@ namespace graphbolt {
namespace
sampling
{
TORCH_LIBRARY
(
graphbolt
,
m
)
{
m
.
class_
<
SampledSubgraph
>
(
"SampledSubgraph"
)
.
def
(
torch
::
init
<>
())
.
def_readwrite
(
"indptr"
,
&
SampledSubgraph
::
indptr
)
.
def_readwrite
(
"indices"
,
&
SampledSubgraph
::
indices
)
.
def_readwrite
(
"reverse_row_node_ids"
,
&
SampledSubgraph
::
reverse_row_node_ids
)
.
def_readwrite
(
"reverse_column_node_ids"
,
&
SampledSubgraph
::
reverse_column_node_ids
)
.
def_readwrite
(
"reverse_edge_ids"
,
&
SampledSubgraph
::
reverse_edge_ids
)
.
def_readwrite
(
"type_per_edge"
,
&
SampledSubgraph
::
type_per_edge
);
m
.
class_
<
CSCSamplingGraph
>
(
"CSCSamplingGraph"
)
.
def
(
"num_nodes"
,
&
CSCSamplingGraph
::
NumNodes
)
.
def
(
"num_edges"
,
&
CSCSamplingGraph
::
NumEdges
)
...
...
python/dgl/graphbolt/__init__.py
View file @
1a649657
...
...
@@ -36,3 +36,5 @@ def load_graphbolt():
load_graphbolt
()
SampledSubgraph
=
torch
.
classes
.
graphbolt
.
SampledSubgraph
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment