Unverified Commit 1a649657 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add subgraph binding (#5741)

parent 7438b108
......@@ -6,8 +6,7 @@
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <graphbolt/sampled_subgraph.h>
#include <string>
#include <vector>
......
......@@ -22,20 +22,20 @@ namespace sampling {
* ```
* auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64});
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto reverse_row_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
* auto reverse_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
*
* SampledSubgraph sampledSubgraph(indptr, indices, reverse_row_node_ids);
* SampledSubgraph sampledSubgraph(indptr, indices, reverse_column_node_ids);
* ```
*
* The `reverse_row_node_ids` indicates that nodes `[3, 3, 101]` in the
* The `reverse_column_node_ids` indicates that nodes `[3, 3, 101]` in the
* original graph are mapped to `[0, 1, 2]` in this subgraph, and because
* `reverse_column_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* `reverse_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* the original node ids without compaction.
*
* If `reverse_column_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* it would indicate a different mapping for the column nodes. Note this is
* inconsistent with row, which is legal, as `3` is mapped to `0` and `1` in the
* row while `2` in the column.
* If `reverse_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* it would indicate a different mapping for the row nodes. Note this is
* inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* the column while `2` in the row.
*/
struct SampledSubgraph : torch::CustomClassHolder {
public:
......@@ -44,58 +44,61 @@ struct SampledSubgraph : torch::CustomClassHolder {
*
* @param indptr CSC format index pointer array.
* @param indices CSC format index array.
* @param reverse_row_node_ids Row's reverse node ids in the original graph.
* @param reverse_column_node_ids Column's reverse node ids in the original
* @param reverse_column_node_ids Row's reverse node ids in the original
* graph.
* @param reverse_row_node_ids Column's reverse node ids in the original
* graph.
* @param reverse_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge.
*/
SampledSubgraph(
torch::Tensor indptr, torch::Tensor indices,
torch::Tensor reverse_row_node_ids,
torch::optional<torch::Tensor> reverse_column_node_ids = torch::nullopt,
torch::Tensor reverse_column_node_ids,
torch::optional<torch::Tensor> reverse_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> reverse_edge_ids = torch::nullopt,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt)
: indptr(indptr),
indices(indices),
reverse_row_node_ids(reverse_row_node_ids),
reverse_column_node_ids(reverse_column_node_ids),
reverse_row_node_ids(reverse_row_node_ids),
reverse_edge_ids(reverse_edge_ids),
type_per_edge(type_per_edge) {}
SampledSubgraph() = default;
/**
* @brief CSC format index pointer array, where the implicit node ids are
* already compacted. And the original ids are stored in the
* `reverse_row_node_ids` field.
* `reverse_column_node_ids` field.
*/
torch::Tensor indptr;
/**
* @brief CSC format index array, where the node ids can be compacted ids or
* original ids. If compacted, the original ids are stored in the
* `reverse_column_node_ids` field.
* `reverse_row_node_ids` field.
*/
torch::Tensor indices;
/**
* @brief Row's reverse node ids in the original graph. A graph structure can
* be treated as a coordinated row and column pair, and this is the the mapped
* ids of the row.
* @brief Column's reverse node ids in the original graph. A graph structure
* can be treated as a coordinated row and column pair, and this is the the
* mapped ids of the column.
*
* @note This is required and the mapping relations can be inconsistent with
* column's.
*/
torch::Tensor reverse_row_node_ids;
torch::Tensor reverse_column_node_ids;
/**
* @brief Column's reverse node ids in the original graph. A graph structure
* @brief Row's reverse node ids in the original graph. A graph structure
* can be treated as a coordinated row and column pair, and this is the the
* mapped ids of the column.
* mapped ids of the row.
*
* @note This is optional and the mapping relations can be inconsistent with
* row's.
*/
torch::optional<torch::Tensor> reverse_column_node_ids;
torch::optional<torch::Tensor> reverse_row_node_ids;
/**
* @brief Reverse edge ids in the original graph, the edge with id
......
......@@ -11,6 +11,16 @@ namespace graphbolt {
namespace sampling {
TORCH_LIBRARY(graphbolt, m) {
m.class_<SampledSubgraph>("SampledSubgraph")
.def(torch::init<>())
.def_readwrite("indptr", &SampledSubgraph::indptr)
.def_readwrite("indices", &SampledSubgraph::indices)
.def_readwrite(
"reverse_row_node_ids", &SampledSubgraph::reverse_row_node_ids)
.def_readwrite(
"reverse_column_node_ids", &SampledSubgraph::reverse_column_node_ids)
.def_readwrite("reverse_edge_ids", &SampledSubgraph::reverse_edge_ids)
.def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge);
m.class_<CSCSamplingGraph>("CSCSamplingGraph")
.def("num_nodes", &CSCSamplingGraph::NumNodes)
.def("num_edges", &CSCSamplingGraph::NumEdges)
......
......@@ -36,3 +36,5 @@ def load_graphbolt():
load_graphbolt()
SampledSubgraph = torch.classes.graphbolt.SampledSubgraph
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment