Unverified Commit c83350da authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Remove node and edge types from C graph (#5721)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
parent fccf0a31
...@@ -15,92 +15,20 @@ ...@@ -15,92 +15,20 @@
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
using NodeTypeList = std::vector<std::string>;
using EdgeTypeList =
std::vector<std::tuple<std::string, std::string, std::string>>;
/** /**
* @brief Structure representing heterogeneous information about a graph. * @brief A sampling oriented csc format graph.
* *
* Example usage: * Example usage:
* *
* Suppose the graph has 3 node types, 3 edge types and 6 edges * Suppose the graph has 3 node types, 3 edge types and 6 edges
* ntypes = {"n1", "n2", "n3"} * auto node_type_offset = {0, 2, 4, 6}
* etypes = {("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2", "e3", "n3")} * auto type_per_edge = {0, 1, 0, 2, 1, 2}
* node_type_offset = [0, 2, 4] * auto graph = CSCSamplingGraph(..., ..., node_type_offset, type_per_edge)
* type_per_edge = [0, 1, 0, 2, 1, 2]
* HeteroInfo info(ntypes, etypes, node_type_offset, type_per_edge);
* *
* This example creates a `HeteroInfo` object with three node types ("n1", "n2", * The `node_type_offset` tensor represents the offset array of node type, the
* "n3") and three edge types (("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2", * given array indicates that node [0, 2) has type id 0, [2, 4) has type id 1,
* "e3", "n3")). The `node_type_offset` tensor represents the offset array of * and [4, 6) has type id 2. And the `type_per_edge` tensor represents the type
* node type, the given array indicates that node [0, 2) has type "n1", [2, 4) * id of each edge.
* has type "n2", and [4, 6) has type "n3". And the `type_per_edge` tensor
* represents the type id of each edge, which is the index in the `etypes`
* tensor.
*/
struct HeteroInfo {
/**
* @brief Constructs a new `HeteroInfo` object.
* @param ntypes List of node types in the graph, where each node type is a
* string.
* @param etypes List of edge types in the graph, where each edge type is a
* string triplet `(str, str, str)`.
* @param node_type_offset Offset array of node type. It is assumed that nodes
* of same type have consecutive ids.
* @param type_per_edge Type id of each edge, where type id is the
* corresponding index of `edge_types`.
*/
HeteroInfo(
const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor& node_type_offset, torch::Tensor& type_per_edge)
: node_types(ntypes),
edge_types(etypes),
node_type_offset(node_type_offset),
type_per_edge(type_per_edge) {}
/** @brief Default constructor. */
HeteroInfo() = default;
/** @brief List of node types in the graph.*/
NodeTypeList node_types;
/** @brief List of edge types in the graph. */
EdgeTypeList edge_types;
/**
* @brief Offset array of node type. The length of it is equal to number of
* node types.
*/
torch::Tensor node_type_offset;
/**
* @brief Type id of each edge, where type id is the corresponding index of
* edge_types. The length of it is equal to the number of edges.
*/
torch::Tensor type_per_edge;
/**
* @brief Magic number to indicate Hetero info version in serialize/
* deserialize stages.
*/
static constexpr int64_t kHeteroInfoSerializeMagic = 0xDD2E60F0F6B4A129;
/**
* @brief Load hetero info from stream.
* @param archive Input stream for deserializing.
*/
void Load(torch::serialize::InputArchive& archive);
/**
* @brief Save hetero info to stream.
* @param archive Output stream for serializing.
*/
void Save(torch::serialize::OutputArchive& archive) const;
};
/**
* @brief A sampling oriented csc format graph.
*/ */
class CSCSamplingGraph : public torch::CustomClassHolder { class CSCSamplingGraph : public torch::CustomClassHolder {
public: public:
...@@ -109,34 +37,22 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -109,34 +37,22 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
/** /**
* @brief Constructor for CSC with data. * @brief Constructor for CSC with data.
* @param num_nodes The number of nodes in the graph.
* @param indptr The CSC format index pointer array. * @param indptr The CSC format index pointer array.
* @param indices The CSC format index array. * @param indices The CSC format index array.
* @param hetero_info Heterogeneous graph information, if present. Nullptr * @param node_type_offset A tensor representing the offset of node types, if
* means it is a homogeneous graph. * present.
* @param type_per_edge A tensor representing the type of each edge, if
* present.
*/ */
CSCSamplingGraph( CSCSamplingGraph(
int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices, torch::Tensor& indptr, torch::Tensor& indices,
const std::shared_ptr<HeteroInfo>& hetero_info); const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge);
/** /**
* @brief Create a homogeneous CSC graph from tensors of CSC format. * @brief Create a homogeneous CSC graph from tensors of CSC format.
* @param num_nodes The number of nodes in the graph.
* @param indptr Index pointer array of the CSC. * @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC. * @param indices Indices array of the CSC.
*
* @return CSCSamplingGraph
*/
static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices);
/**
* @brief Create a heterogeneous CSC graph from tensors of CSC format.
* @param num_nodes The number of nodes in the graph.
* @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC.
* @param ntypes A list of node types, if present.
* @param etypes A list of edge types, if present.
* @param node_type_offset A tensor representing the offset of node types, if * @param node_type_offset A tensor representing the offset of node types, if
* present. * present.
* @param type_per_edge A tensor representing the type of each edge, if * @param type_per_edge A tensor representing the type of each edge, if
...@@ -144,13 +60,13 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -144,13 +60,13 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* *
* @return CSCSamplingGraph * @return CSCSamplingGraph
*/ */
static c10::intrusive_ptr<CSCSamplingGraph> FromCSCWithHeteroInfo( static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices, torch::Tensor indptr, torch::Tensor indices,
const NodeTypeList& ntypes, const EdgeTypeList& etypes, const torch::optional<torch::Tensor>& node_type_offset,
torch::Tensor node_type_offset, torch::Tensor type_per_edge); const torch::optional<torch::Tensor>& type_per_edge);
/** @brief Get the number of nodes. */ /** @brief Get the number of nodes. */
int64_t NumNodes() const { return num_nodes_; } int64_t NumNodes() const { return indptr_.size(0) - 1; }
/** @brief Get the number of edges. */ /** @brief Get the number of edges. */
int64_t NumEdges() const { return indices_.size(0); } int64_t NumEdges() const { return indices_.size(0); }
...@@ -161,25 +77,14 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -161,25 +77,14 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
/** @brief Get the index tensor. */ /** @brief Get the index tensor. */
const torch::Tensor Indices() const { return indices_; } const torch::Tensor Indices() const { return indices_; }
/** @brief Check if the graph is heterogeneous. */
inline bool IsHeterogeneous() const { return hetero_info_ != nullptr; }
/** @brief Get the node type offset tensor for a heterogeneous graph. */ /** @brief Get the node type offset tensor for a heterogeneous graph. */
inline const torch::Tensor NodeTypeOffset() const { inline const torch::optional<torch::Tensor> NodeTypeOffset() const {
return hetero_info_->node_type_offset; return node_type_offset_;
}
/** @brief Get the list of node types for a heterogeneous graph. */
inline NodeTypeList& NodeTypes() const { return hetero_info_->node_types; }
/** @brief Get the list of edge types for a heterogeneous graph. */
inline const EdgeTypeList& EdgeTypes() const {
return hetero_info_->edge_types;
} }
/** @brief Get the edge type tensor for a heterogeneous graph. */ /** @brief Get the edge type tensor for a heterogeneous graph. */
inline const torch::Tensor TypePerEdge() const { inline const torch::optional<torch::Tensor> TypePerEdge() const {
return hetero_info_->type_per_edge; return type_per_edge_;
} }
/** /**
...@@ -201,14 +106,27 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -201,14 +106,27 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
void Save(torch::serialize::OutputArchive& archive) const; void Save(torch::serialize::OutputArchive& archive) const;
private: private:
/** @brief The number of nodes of the graph. */
int64_t num_nodes_ = 0;
/** @brief CSC format index pointer array. */ /** @brief CSC format index pointer array. */
torch::Tensor indptr_; torch::Tensor indptr_;
/** @brief CSC format index array. */ /** @brief CSC format index array. */
torch::Tensor indices_; torch::Tensor indices_;
/** @brief Heterogeneous graph information, if present. */
std::shared_ptr<HeteroInfo> hetero_info_; /**
* @brief Offset array of node type. The length of it is equal to the number
* of node types + 1. The tensor is in ascending order as nodes of the same
* type have continuous IDs, and larger node IDs are paired with larger node
* type IDs. Its first value is 0 and last value is the number of nodes. And
* nodes with ID between `node_type_offset_[i] ~ node_type_offset_[i+1]` are
* of type id `i`.
*/
torch::optional<torch::Tensor> node_type_offset_;
/**
* @brief Type id of each edge, where type id is the corresponding index of
* edge types. The length of it is equal to the number of edges.
*/
torch::optional<torch::Tensor> type_per_edge_;
}; };
} // namespace sampling } // namespace sampling
......
...@@ -10,67 +10,34 @@ ...@@ -10,67 +10,34 @@
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
void HeteroInfo::Load(torch::serialize::InputArchive& archive) {
const int64_t magic_num =
read_from_archive(archive, "HeteroInfo/magic_num").toInt();
TORCH_CHECK(
magic_num == kHeteroInfoSerializeMagic,
"Magic numbers mismatch when loading HeteroInfo.");
node_types = read_from_archive(archive, "HeteroInfo/node_types")
.to<decltype(node_types)>();
edge_types = read_from_archive(archive, "HeteroInfo/edge_types")
.to<decltype(edge_types)>();
node_type_offset =
read_from_archive(archive, "HeteroInfo/node_type_offset").toTensor();
type_per_edge =
read_from_archive(archive, "HeteroInfo/type_per_edge").toTensor();
}
void HeteroInfo::Save(torch::serialize::OutputArchive& archive) const {
archive.write("HeteroInfo/magic_num", kHeteroInfoSerializeMagic);
archive.write("HeteroInfo/node_types", node_types);
archive.write("HeteroInfo/edge_types", edge_types);
archive.write("HeteroInfo/node_type_offset", node_type_offset);
archive.write("HeteroInfo/type_per_edge", type_per_edge);
}
CSCSamplingGraph::CSCSamplingGraph( CSCSamplingGraph::CSCSamplingGraph(
int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices, torch::Tensor& indptr, torch::Tensor& indices,
const std::shared_ptr<HeteroInfo>& hetero_info) const torch::optional<torch::Tensor>& node_type_offset,
: num_nodes_(num_nodes), const torch::optional<torch::Tensor>& type_per_edge)
indptr_(indptr), : indptr_(indptr),
indices_(indices), indices_(indices),
hetero_info_(hetero_info) { node_type_offset_(node_type_offset),
TORCH_CHECK(num_nodes >= 0); type_per_edge_(type_per_edge) {
TORCH_CHECK(indptr.dim() == 1); TORCH_CHECK(indptr.dim() == 1);
TORCH_CHECK(indices.dim() == 1); TORCH_CHECK(indices.dim() == 1);
TORCH_CHECK(indptr.size(0) == num_nodes + 1);
TORCH_CHECK(indptr.device() == indices.device()); TORCH_CHECK(indptr.device() == indices.device());
} }
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC( c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices) { torch::Tensor indptr, torch::Tensor indices,
return c10::make_intrusive<CSCSamplingGraph>( const torch::optional<torch::Tensor>& node_type_offset,
num_nodes, indptr, indices, nullptr); const torch::optional<torch::Tensor>& type_per_edge) {
} if (node_type_offset.has_value()) {
auto& offset = node_type_offset.value();
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo( TORCH_CHECK(offset.dim() == 1);
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices, }
const NodeTypeList& ntypes, const EdgeTypeList& etypes, if (type_per_edge.has_value()) {
torch::Tensor node_type_offset, torch::Tensor type_per_edge) { TORCH_CHECK(type_per_edge.value().dim() == 1);
TORCH_CHECK(node_type_offset.size(0) > 0); TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
TORCH_CHECK(node_type_offset.dim() == 1); }
TORCH_CHECK(type_per_edge.size(0) > 0);
TORCH_CHECK(type_per_edge.dim() == 1);
TORCH_CHECK(node_type_offset.device() == type_per_edge.device());
TORCH_CHECK(type_per_edge.device() == indices.device());
TORCH_CHECK(!ntypes.empty());
TORCH_CHECK(!etypes.empty());
auto hetero_info = std::make_shared<HeteroInfo>(
ntypes, etypes, node_type_offset, type_per_edge);
return c10::make_intrusive<CSCSamplingGraph>( return c10::make_intrusive<CSCSamplingGraph>(
num_nodes, indptr, indices, hetero_info); indptr, indices, node_type_offset, type_per_edge);
} }
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
...@@ -79,26 +46,14 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { ...@@ -79,26 +46,14 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
TORCH_CHECK( TORCH_CHECK(
magic_num == kCSCSamplingGraphSerializeMagic, magic_num == kCSCSamplingGraphSerializeMagic,
"Magic numbers mismatch when loading CSCSamplingGraph."); "Magic numbers mismatch when loading CSCSamplingGraph.");
num_nodes_ = read_from_archive(archive, "CSCSamplingGraph/num_nodes").toInt();
indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor(); indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor();
indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor(); indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor();
const bool is_heterogeneous =
read_from_archive(archive, "CSCSamplingGraph/is_hetero").toBool();
if (is_heterogeneous) {
hetero_info_ = std::make_shared<HeteroInfo>();
hetero_info_->Load(archive);
}
} }
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const { void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic); archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
archive.write("CSCSamplingGraph/num_nodes", num_nodes_);
archive.write("CSCSamplingGraph/indptr", indptr_); archive.write("CSCSamplingGraph/indptr", indptr_);
archive.write("CSCSamplingGraph/indices", indices_); archive.write("CSCSamplingGraph/indices", indices_);
archive.write("CSCSamplingGraph/is_hetero", IsHeterogeneous());
if (IsHeterogeneous()) {
hetero_info_->Save(archive);
}
} }
} // namespace sampling } // namespace sampling
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment