Unverified Commit c83350da authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Remove node and edge types from C graph (#5721)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
parent fccf0a31
......@@ -15,92 +15,20 @@
namespace graphbolt {
namespace sampling {
using NodeTypeList = std::vector<std::string>;
using EdgeTypeList =
std::vector<std::tuple<std::string, std::string, std::string>>;
/**
* @brief Structure representing heterogeneous information about a graph.
* @brief A sampling oriented csc format graph.
*
* Example usage:
*
* Suppose the graph has 3 node types, 3 edge types and 6 edges
* ntypes = {"n1", "n2", "n3"}
* etypes = {("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2", "e3", "n3")}
* node_type_offset = [0, 2, 4]
* type_per_edge = [0, 1, 0, 2, 1, 2]
* HeteroInfo info(ntypes, etypes, node_type_offset, type_per_edge);
* auto node_type_offset = {0, 2, 4, 6}
* auto type_per_edge = {0, 1, 0, 2, 1, 2}
* auto graph = CSCSamplingGraph(..., ..., node_type_offset, type_per_edge)
*
* This example creates a `HeteroInfo` object with three node types ("n1", "n2",
* "n3") and three edge types (("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2",
* "e3", "n3")). The `node_type_offset` tensor represents the offset array of
* node type, the given array indicates that node [0, 2) has type "n1", [2, 4)
* has type "n2", and [4, 6) has type "n3". And the `type_per_edge` tensor
* represents the type id of each edge, which is the index in the `etypes`
* tensor.
*/
struct HeteroInfo {
/**
* @brief Constructs a new `HeteroInfo` object.
* @param ntypes List of node types in the graph, where each node type is a
* string.
* @param etypes List of edge types in the graph, where each edge type is a
* string triplet `(str, str, str)`.
* @param node_type_offset Offset array of node type. It is assumed that nodes
* of same type have consecutive ids.
* @param type_per_edge Type id of each edge, where type id is the
* corresponding index of `edge_types`.
*/
HeteroInfo(
const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor& node_type_offset, torch::Tensor& type_per_edge)
: node_types(ntypes),
edge_types(etypes),
node_type_offset(node_type_offset),
type_per_edge(type_per_edge) {}
/** @brief Default constructor. */
HeteroInfo() = default;
/** @brief List of node types in the graph.*/
NodeTypeList node_types;
/** @brief List of edge types in the graph. */
EdgeTypeList edge_types;
/**
* @brief Offset array of node type. The length of it is equal to number of
* node types.
*/
torch::Tensor node_type_offset;
/**
* @brief Type id of each edge, where type id is the corresponding index of
* edge_types. The length of it is equal to the number of edges.
*/
torch::Tensor type_per_edge;
/**
* @brief Magic number to indicate Hetero info version in serialize/
* deserialize stages.
*/
static constexpr int64_t kHeteroInfoSerializeMagic = 0xDD2E60F0F6B4A129;
/**
* @brief Load hetero info from stream.
* @param archive Input stream for deserializing.
*/
void Load(torch::serialize::InputArchive& archive);
/**
* @brief Save hetero info to stream.
* @param archive Output stream for serializing.
*/
void Save(torch::serialize::OutputArchive& archive) const;
};
/**
* @brief A sampling oriented csc format graph.
* The `node_type_offset` tensor represents the offset array of node type, the
* given array indicates that node [0, 2) has type id 0, [2, 4) has type id 1,
* and [4, 6) has type id 2. And the `type_per_edge` tensor represents the type
* id of each edge.
*/
class CSCSamplingGraph : public torch::CustomClassHolder {
public:
......@@ -109,34 +37,22 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
/**
* @brief Constructor for CSC with data.
* @param num_nodes The number of nodes in the graph.
* @param indptr The CSC format index pointer array.
* @param indices The CSC format index array.
* @param hetero_info Heterogeneous graph information, if present. Nullptr
* means it is a homogeneous graph.
* @param node_type_offset A tensor representing the offset of node types, if
* present.
* @param type_per_edge A tensor representing the type of each edge, if
* present.
*/
CSCSamplingGraph(
int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices,
const std::shared_ptr<HeteroInfo>& hetero_info);
torch::Tensor& indptr, torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge);
/**
* @brief Create a homogeneous CSC graph from tensors of CSC format.
* @param num_nodes The number of nodes in the graph.
* @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC.
*
* @return CSCSamplingGraph
*/
static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices);
/**
* @brief Create a heterogeneous CSC graph from tensors of CSC format.
* @param num_nodes The number of nodes in the graph.
* @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC.
* @param ntypes A list of node types, if present.
* @param etypes A list of edge types, if present.
* @param node_type_offset A tensor representing the offset of node types, if
* present.
* @param type_per_edge A tensor representing the type of each edge, if
......@@ -144,13 +60,13 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*
* @return CSCSamplingGraph
*/
static c10::intrusive_ptr<CSCSamplingGraph> FromCSCWithHeteroInfo(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor node_type_offset, torch::Tensor type_per_edge);
static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
torch::Tensor indptr, torch::Tensor indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge);
/** @brief Get the number of nodes. */
int64_t NumNodes() const { return num_nodes_; }
int64_t NumNodes() const { return indptr_.size(0) - 1; }
/** @brief Get the number of edges. */
int64_t NumEdges() const { return indices_.size(0); }
......@@ -161,25 +77,14 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
/** @brief Get the index tensor. */
const torch::Tensor Indices() const { return indices_; }
/** @brief Check if the graph is heterogeneous. */
inline bool IsHeterogeneous() const { return hetero_info_ != nullptr; }
/** @brief Get the node type offset tensor for a heterogeneous graph. */
inline const torch::Tensor NodeTypeOffset() const {
return hetero_info_->node_type_offset;
}
/** @brief Get the list of node types for a heterogeneous graph. */
inline NodeTypeList& NodeTypes() const { return hetero_info_->node_types; }
/** @brief Get the list of edge types for a heterogeneous graph. */
inline const EdgeTypeList& EdgeTypes() const {
return hetero_info_->edge_types;
inline const torch::optional<torch::Tensor> NodeTypeOffset() const {
return node_type_offset_;
}
/** @brief Get the edge type tensor for a heterogeneous graph. */
inline const torch::Tensor TypePerEdge() const {
return hetero_info_->type_per_edge;
inline const torch::optional<torch::Tensor> TypePerEdge() const {
return type_per_edge_;
}
/**
......@@ -201,14 +106,27 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
void Save(torch::serialize::OutputArchive& archive) const;
private:
/** @brief The number of nodes of the graph. */
int64_t num_nodes_ = 0;
/** @brief CSC format index pointer array. */
torch::Tensor indptr_;
/** @brief CSC format index array. */
torch::Tensor indices_;
/** @brief Heterogeneous graph information, if present. */
std::shared_ptr<HeteroInfo> hetero_info_;
/**
* @brief Offset array of node type. The length of it is equal to the number
* of node types + 1. The tensor is in ascending order as nodes of the same
* type have continuous IDs, and larger node IDs are paired with larger node
* type IDs. Its first value is 0 and last value is the number of nodes. And
* nodes with ID between `node_type_offset_[i] ~ node_type_offset_[i+1]` are
* of type id `i`.
*/
torch::optional<torch::Tensor> node_type_offset_;
/**
* @brief Type id of each edge, where type id is the corresponding index of
* edge types. The length of it is equal to the number of edges.
*/
torch::optional<torch::Tensor> type_per_edge_;
};
} // namespace sampling
......
......@@ -10,67 +10,34 @@
namespace graphbolt {
namespace sampling {
void HeteroInfo::Load(torch::serialize::InputArchive& archive) {
const int64_t magic_num =
read_from_archive(archive, "HeteroInfo/magic_num").toInt();
TORCH_CHECK(
magic_num == kHeteroInfoSerializeMagic,
"Magic numbers mismatch when loading HeteroInfo.");
node_types = read_from_archive(archive, "HeteroInfo/node_types")
.to<decltype(node_types)>();
edge_types = read_from_archive(archive, "HeteroInfo/edge_types")
.to<decltype(edge_types)>();
node_type_offset =
read_from_archive(archive, "HeteroInfo/node_type_offset").toTensor();
type_per_edge =
read_from_archive(archive, "HeteroInfo/type_per_edge").toTensor();
}
void HeteroInfo::Save(torch::serialize::OutputArchive& archive) const {
archive.write("HeteroInfo/magic_num", kHeteroInfoSerializeMagic);
archive.write("HeteroInfo/node_types", node_types);
archive.write("HeteroInfo/edge_types", edge_types);
archive.write("HeteroInfo/node_type_offset", node_type_offset);
archive.write("HeteroInfo/type_per_edge", type_per_edge);
}
CSCSamplingGraph::CSCSamplingGraph(
int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices,
const std::shared_ptr<HeteroInfo>& hetero_info)
: num_nodes_(num_nodes),
indptr_(indptr),
torch::Tensor& indptr, torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge)
: indptr_(indptr),
indices_(indices),
hetero_info_(hetero_info) {
TORCH_CHECK(num_nodes >= 0);
node_type_offset_(node_type_offset),
type_per_edge_(type_per_edge) {
TORCH_CHECK(indptr.dim() == 1);
TORCH_CHECK(indices.dim() == 1);
TORCH_CHECK(indptr.size(0) == num_nodes + 1);
TORCH_CHECK(indptr.device() == indices.device());
}
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices) {
return c10::make_intrusive<CSCSamplingGraph>(
num_nodes, indptr, indices, nullptr);
}
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor node_type_offset, torch::Tensor type_per_edge) {
TORCH_CHECK(node_type_offset.size(0) > 0);
TORCH_CHECK(node_type_offset.dim() == 1);
TORCH_CHECK(type_per_edge.size(0) > 0);
TORCH_CHECK(type_per_edge.dim() == 1);
TORCH_CHECK(node_type_offset.device() == type_per_edge.device());
TORCH_CHECK(type_per_edge.device() == indices.device());
TORCH_CHECK(!ntypes.empty());
TORCH_CHECK(!etypes.empty());
auto hetero_info = std::make_shared<HeteroInfo>(
ntypes, etypes, node_type_offset, type_per_edge);
torch::Tensor indptr, torch::Tensor indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge) {
if (node_type_offset.has_value()) {
auto& offset = node_type_offset.value();
TORCH_CHECK(offset.dim() == 1);
}
if (type_per_edge.has_value()) {
TORCH_CHECK(type_per_edge.value().dim() == 1);
TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
}
return c10::make_intrusive<CSCSamplingGraph>(
num_nodes, indptr, indices, hetero_info);
indptr, indices, node_type_offset, type_per_edge);
}
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
......@@ -79,26 +46,14 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
TORCH_CHECK(
magic_num == kCSCSamplingGraphSerializeMagic,
"Magic numbers mismatch when loading CSCSamplingGraph.");
num_nodes_ = read_from_archive(archive, "CSCSamplingGraph/num_nodes").toInt();
indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor();
indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor();
const bool is_heterogeneous =
read_from_archive(archive, "CSCSamplingGraph/is_hetero").toBool();
if (is_heterogeneous) {
hetero_info_ = std::make_shared<HeteroInfo>();
hetero_info_->Load(archive);
}
}
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
archive.write("CSCSamplingGraph/num_nodes", num_nodes_);
archive.write("CSCSamplingGraph/indptr", indptr_);
archive.write("CSCSamplingGraph/indices", indices_);
archive.write("CSCSamplingGraph/is_hetero", IsHeterogeneous());
if (IsHeterogeneous()) {
hetero_info_->Save(archive);
}
}
} // namespace sampling
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment