"docs/vscode:/vscode.git/clone" did not exist on "dad8751d29328957efb4732a1bbcf38bee7bc184"
Unverified Commit 1a649657 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add subgraph binding (#5741)

parent 7438b108
...@@ -6,8 +6,7 @@ ...@@ -6,8 +6,7 @@
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_ #ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_ #define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#include <torch/custom_class.h> #include <graphbolt/sampled_subgraph.h>
#include <torch/torch.h>
#include <string> #include <string>
#include <vector> #include <vector>
......
...@@ -22,20 +22,20 @@ namespace sampling { ...@@ -22,20 +22,20 @@ namespace sampling {
* ``` * ```
* auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64}); * auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64});
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64}); * auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto reverse_row_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64}); * auto reverse_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
* *
* SampledSubgraph sampledSubgraph(indptr, indices, reverse_row_node_ids); * SampledSubgraph sampledSubgraph(indptr, indices, reverse_column_node_ids);
* ``` * ```
* *
* The `reverse_row_node_ids` indicates that nodes `[3, 3, 101]` in the * The `reverse_column_node_ids` indicates that nodes `[3, 3, 101]` in the
* original graph are mapped to `[0, 1, 2]` in this subgraph, and because * original graph are mapped to `[0, 1, 2]` in this subgraph, and because
* `reverse_column_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just * `reverse_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* the original node ids without compaction. * the original node ids without compaction.
* *
* If `reverse_column_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`, * If `reverse_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* it would indicate a different mapping for the column nodes. Note this is * it would indicate a different mapping for the row nodes. Note this is
* inconsistent with row, which is legal, as `3` is mapped to `0` and `1` in the * inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* row while `2` in the column. * the column while `2` in the row.
*/ */
struct SampledSubgraph : torch::CustomClassHolder { struct SampledSubgraph : torch::CustomClassHolder {
public: public:
...@@ -44,58 +44,61 @@ struct SampledSubgraph : torch::CustomClassHolder { ...@@ -44,58 +44,61 @@ struct SampledSubgraph : torch::CustomClassHolder {
* *
* @param indptr CSC format index pointer array. * @param indptr CSC format index pointer array.
* @param indices CSC format index array. * @param indices CSC format index array.
* @param reverse_row_node_ids Row's reverse node ids in the original graph. * @param reverse_column_node_ids Row's reverse node ids in the original
* @param reverse_column_node_ids Column's reverse node ids in the original * graph.
* @param reverse_row_node_ids Column's reverse node ids in the original
* graph. * graph.
* @param reverse_edge_ids Reverse edge ids in the original graph. * @param reverse_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge. * @param type_per_edge Type id of each edge.
*/ */
SampledSubgraph( SampledSubgraph(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor indptr, torch::Tensor indices,
torch::Tensor reverse_row_node_ids, torch::Tensor reverse_column_node_ids,
torch::optional<torch::Tensor> reverse_column_node_ids = torch::nullopt, torch::optional<torch::Tensor> reverse_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> reverse_edge_ids = torch::nullopt, torch::optional<torch::Tensor> reverse_edge_ids = torch::nullopt,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt) torch::optional<torch::Tensor> type_per_edge = torch::nullopt)
: indptr(indptr), : indptr(indptr),
indices(indices), indices(indices),
reverse_row_node_ids(reverse_row_node_ids),
reverse_column_node_ids(reverse_column_node_ids), reverse_column_node_ids(reverse_column_node_ids),
reverse_row_node_ids(reverse_row_node_ids),
reverse_edge_ids(reverse_edge_ids), reverse_edge_ids(reverse_edge_ids),
type_per_edge(type_per_edge) {} type_per_edge(type_per_edge) {}
SampledSubgraph() = default;
/** /**
* @brief CSC format index pointer array, where the implicit node ids are * @brief CSC format index pointer array, where the implicit node ids are
* already compacted. And the original ids are stored in the * already compacted. And the original ids are stored in the
* `reverse_row_node_ids` field. * `reverse_column_node_ids` field.
*/ */
torch::Tensor indptr; torch::Tensor indptr;
/** /**
* @brief CSC format index array, where the node ids can be compacted ids or * @brief CSC format index array, where the node ids can be compacted ids or
* original ids. If compacted, the original ids are stored in the * original ids. If compacted, the original ids are stored in the
* `reverse_column_node_ids` field. * `reverse_row_node_ids` field.
*/ */
torch::Tensor indices; torch::Tensor indices;
/** /**
* @brief Row's reverse node ids in the original graph. A graph structure can * @brief Column's reverse node ids in the original graph. A graph structure
* be treated as a coordinated row and column pair, and this is the the mapped * can be treated as a coordinated row and column pair, and this is the the
* ids of the row. * mapped ids of the column.
* *
* @note This is required and the mapping relations can be inconsistent with * @note This is required and the mapping relations can be inconsistent with
* column's. * column's.
*/ */
torch::Tensor reverse_row_node_ids; torch::Tensor reverse_column_node_ids;
/** /**
* @brief Column's reverse node ids in the original graph. A graph structure * @brief Row's reverse node ids in the original graph. A graph structure
* can be treated as a coordinated row and column pair, and this is the the * can be treated as a coordinated row and column pair, and this is the the
* mapped ids of the column. * mapped ids of the row.
* *
* @note This is optional and the mapping relations can be inconsistent with * @note This is optional and the mapping relations can be inconsistent with
* row's. * row's.
*/ */
torch::optional<torch::Tensor> reverse_column_node_ids; torch::optional<torch::Tensor> reverse_row_node_ids;
/** /**
* @brief Reverse edge ids in the original graph, the edge with id * @brief Reverse edge ids in the original graph, the edge with id
......
...@@ -11,6 +11,16 @@ namespace graphbolt { ...@@ -11,6 +11,16 @@ namespace graphbolt {
namespace sampling { namespace sampling {
TORCH_LIBRARY(graphbolt, m) { TORCH_LIBRARY(graphbolt, m) {
m.class_<SampledSubgraph>("SampledSubgraph")
.def(torch::init<>())
.def_readwrite("indptr", &SampledSubgraph::indptr)
.def_readwrite("indices", &SampledSubgraph::indices)
.def_readwrite(
"reverse_row_node_ids", &SampledSubgraph::reverse_row_node_ids)
.def_readwrite(
"reverse_column_node_ids", &SampledSubgraph::reverse_column_node_ids)
.def_readwrite("reverse_edge_ids", &SampledSubgraph::reverse_edge_ids)
.def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge);
m.class_<CSCSamplingGraph>("CSCSamplingGraph") m.class_<CSCSamplingGraph>("CSCSamplingGraph")
.def("num_nodes", &CSCSamplingGraph::NumNodes) .def("num_nodes", &CSCSamplingGraph::NumNodes)
.def("num_edges", &CSCSamplingGraph::NumEdges) .def("num_edges", &CSCSamplingGraph::NumEdges)
......
...@@ -36,3 +36,5 @@ def load_graphbolt(): ...@@ -36,3 +36,5 @@ def load_graphbolt():
load_graphbolt() load_graphbolt()
SampledSubgraph = torch.classes.graphbolt.SampledSubgraph
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment