Unverified commit 2cb7c69d, authored by Rhett Ying, committed by GitHub
Browse files

[GraphBolt] Add InSubgraph() for CSCSamplingGraph at the C++ level (#5728)

parent 4c54a4c8
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "sampled_subgraph.h"
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
...@@ -104,6 +106,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -104,6 +106,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*/ */
void Save(torch::serialize::OutputArchive& archive) const; void Save(torch::serialize::OutputArchive& archive) const;
/**
* @brief Return the subgraph induced on the inbound edges of the given nodes.
*
* For every entry of `nodes`, the CSC range delimited by that node's two
* consecutive index-pointer entries is extracted. The result therefore holds
* the in-edges of the requested nodes, the positions of those edges in this
* graph's edge array, and the per-edge types when the graph stores them.
*
* @param nodes Type agnostic node IDs to form the subgraph. It is indexed
* along dimension 0 and every element is read as an int64 scalar.
*
* @return SampledSubgraph.
*/
c10::intrusive_ptr<SampledSubgraph> InSubgraph(
const torch::Tensor& nodes) const;
private: private:
/** @brief CSC format index pointer array. */ /** @brief CSC format index pointer array. */
torch::Tensor indptr_; torch::Tensor indptr_;
......
...@@ -78,5 +78,42 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const { ...@@ -78,5 +78,42 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
} }
} }
/**
 * @brief Return the subgraph induced on the inbound edges of the given nodes.
 *
 * For each seed the slice [indptr_[v], indptr_[v + 1]) of `indices_` (and of
 * `type_per_edge_`, when present) is collected, together with the matching
 * positions in the edge array as edge IDs. Columns of the result are compacted:
 * only seeds that own at least one in-edge appear in the returned index
 * pointer and in the reported column node IDs.
 *
 * NOTE(review): the compacted index pointer follows ascending node-ID order,
 * whereas the concatenated per-seed slices follow the order of `nodes`. The
 * two only line up when `nodes` is sorted ascending and free of duplicates --
 * confirm that callers guarantee this.
 *
 * @param nodes Type agnostic node IDs to form the subgraph; indexed along
 * dimension 0, each element read as int64. May be empty.
 *
 * @return SampledSubgraph built from: the compacted index pointer, the
 * concatenated neighbor indices, the IDs of the seeds that have in-edges,
 * the full node-ID range [0, NumNodes()), the edge IDs, and the optional
 * per-edge types.
 */
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
    const torch::Tensor& nodes) const {
  using namespace torch::indexing;
  constexpr int64_t kDefaultGrainSize = 100;
  const int64_t num_seeds = nodes.size(0);

  if (num_seeds == 0) {
    // torch::cat() rejects an empty tensor list, so an empty seed set is
    // answered directly with an edgeless subgraph instead of throwing.
    return c10::make_intrusive<SampledSubgraph>(
        torch::zeros({1}, indptr_.dtype()).cumsum(0),
        indices_.slice(0, 0, 0), torch::arange(0, 0),
        torch::arange(0, NumNodes()), torch::arange(0, 0),
        type_per_edge_
            ? torch::optional<torch::Tensor>{type_per_edge_.value().slice(
                  0, 0, 0)}
            : torch::nullopt);
  }

  // Scratch array shaped like indptr_: slot v + 1 receives the in-degree of
  // seed v, every other slot stays zero.
  torch::Tensor indptr = torch::zeros_like(indptr_);
  std::vector<torch::Tensor> indices_arr(num_seeds);
  std::vector<torch::Tensor> edge_ids_arr(num_seeds);
  std::vector<torch::Tensor> type_per_edge_arr(num_seeds);
  torch::parallel_for(
      0, num_seeds, kDefaultGrainSize, [&](int64_t start, int64_t end) {
        for (int64_t i = start; i < end; ++i) {
          const int64_t node_id = nodes[i].item<int64_t>();
          const int64_t start_idx = indptr_[node_id].item<int64_t>();
          const int64_t end_idx = indptr_[node_id + 1].item<int64_t>();
          indptr[node_id + 1] = end_idx - start_idx;
          indices_arr[i] = indices_.slice(0, start_idx, end_idx);
          // Edge IDs are the positions of the edges in the original arrays.
          edge_ids_arr[i] = torch::arange(start_idx, end_idx);
          if (type_per_edge_) {
            type_per_edge_arr[i] =
                type_per_edge_.value().slice(0, start_idx, end_idx);
          }
        }
      });

  // Positions of the non-zero degrees; each one equals (seed node ID + 1)
  // because degrees were stored one slot to the right.
  const torch::Tensor nonzero_idx = torch::nonzero(indptr).reshape(-1);
  torch::Tensor compact_indptr =
      torch::zeros({nonzero_idx.size(0) + 1}, indptr_.dtype());
  compact_indptr.index_put_({Slice(1, None)}, indptr.index({nonzero_idx}));
  return c10::make_intrusive<SampledSubgraph>(
      compact_indptr.cumsum(0), torch::cat(indices_arr),
      // Undo the +1 shift so that real node IDs are reported.
      nonzero_idx - 1, torch::arange(0, NumNodes()),
      torch::cat(edge_ids_arr),
      type_per_edge_
          ? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)}
          : torch::nullopt);
}
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment