Unverified commit af8451c2, authored by peizhou001, committed by GitHub
Browse files

[GraphBolt] Use string triplet for edge type (#5712)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-25-242.ap-northeast-1.compute.internal>
parent 154b93c7
......@@ -15,36 +15,44 @@
namespace graphbolt {
namespace sampling {
using StringList = std::vector<std::string>;
using NodeTypeList = std::vector<std::string>;
using EdgeTypeList =
std::vector<std::tuple<std::string, std::string, std::string>>;
/**
* @brief Structure representing heterogeneous information about a graph.
*
* Example usage:
*
* Suppose the graph has 6 edges
* node_offset = [0, 2, 4]
* edge_types = [0, 1, 0, 2, 1, 2]
* HeteroInfo info({"A", "B", "C"}, {"X", "Y", "Z"}, node_offset, edge_types);
* Suppose the graph has 3 node types, 3 edge types and 6 edges
* ntypes = {"n1", "n2", "n3"}
* etypes = {("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2", "e3", "n3")}
* node_type_offset = [0, 2, 4]
* type_per_edge = [0, 1, 0, 2, 1, 2]
* HeteroInfo info(ntypes, etypes, node_type_offset, type_per_edge);
*
* This example creates a `HeteroInfo` object with three node types ("A", "B",
* "C") and three edge types ("X", "Y", "Z"). The `node_offset` tensor
* represents the offset array of node type, the given array indicates that node
* [0, 2) has type "A", [2, 4) has type "B", and [4, 6) has type "C". And the
* `edge_types` tensor represents the type id of each edge.
* This example creates a `HeteroInfo` object with three node types ("n1", "n2",
* "n3") and three edge types (("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2",
* "e3", "n3")). The `node_type_offset` tensor represents the offset array of
* node type, the given array indicates that node [0, 2) has type "n1", [2, 4)
* has type "n2", and [4, 6) has type "n3". And the `type_per_edge` tensor
* represents the type id of each edge, which is the index of the edge's type
* in the `etypes` list.
*/
struct HeteroInfo {
/**
* @brief Constructs a new `HeteroInfo` object.
* @param ntypes List of node types in the graph.
* @param etypes List of edge types in the graph.
* @param ntypes List of node types in the graph, where each node type is a
* string.
* @param etypes List of edge types in the graph, where each edge type is a
* string triplet `(str, str, str)`.
* @param node_type_offset Offset array of node type. It is assumed that nodes
* of the same type have consecutive ids.
* @param type_per_edge Type id of each edge, where the type id is the index
* of the edge's type in `etypes`.
*/
HeteroInfo(
const StringList& ntypes, const StringList& etypes,
const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor& node_type_offset, torch::Tensor& type_per_edge)
: node_types(ntypes),
edge_types(etypes),
......@@ -55,10 +63,10 @@ struct HeteroInfo {
HeteroInfo() = default;
/** @brief List of node types in the graph.*/
StringList node_types;
NodeTypeList node_types;
/** @brief List of edge types in the graph. */
StringList edge_types;
EdgeTypeList edge_types;
/**
* @brief Offset array of node type. The length of it is equal to number of
......@@ -138,7 +146,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*/
static c10::intrusive_ptr<CSCSamplingGraph> FromCSCWithHeteroInfo(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
const StringList& ntypes, const StringList& etypes,
const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor node_type_offset, torch::Tensor type_per_edge);
/** @brief Get the number of nodes. */
......@@ -162,10 +170,10 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
}
/** @brief Get the list of node types for a heterogeneous graph. */
inline StringList& NodeTypes() const { return hetero_info_->node_types; }
inline NodeTypeList& NodeTypes() const { return hetero_info_->node_types; }
/** @brief Get the list of edge types for a heterogeneous graph. */
inline const StringList& EdgeTypes() const {
inline const EdgeTypeList& EdgeTypes() const {
return hetero_info_->edge_types;
}
......
......@@ -56,7 +56,7 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
const StringList& ntypes, const StringList& etypes,
const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor node_type_offset, torch::Tensor type_per_edge) {
TORCH_CHECK(node_type_offset.size(0) > 0);
TORCH_CHECK(node_type_offset.dim() == 1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment