You need to sign in or sign up before continuing.
Unverified Commit af8451c2 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Use string triplet for edge type (#5712)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-25-242.ap-northeast-1.compute.internal>
parent 154b93c7
...@@ -15,36 +15,44 @@ ...@@ -15,36 +15,44 @@
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
using StringList = std::vector<std::string>; using NodeTypeList = std::vector<std::string>;
using EdgeTypeList =
std::vector<std::tuple<std::string, std::string, std::string>>;
/** /**
* @brief Structure representing heterogeneous information about a graph. * @brief Structure representing heterogeneous information about a graph.
* *
* Example usage: * Example usage:
* *
* Suppose the graph has 6 edges * Suppose the graph has 3 node types, 3 edge types and 6 edges
* node_offset = [0, 2, 4] * ntypes = {"n1", "n2", "n3"}
* edge_types = [0, 1, 0, 2, 1, 2] * etypes = {("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2", "e3", "n3")}
* HeteroInfo info({"A", "B", "C"}, {"X", "Y", "Z"}, node_offset, edge_types); * node_type_offset = [0, 2, 4]
* type_per_edge = [0, 1, 0, 2, 1, 2]
* HeteroInfo info(ntypes, etypes, node_type_offset, type_per_edge);
* *
* This example creates a `HeteroInfo` object with three node types ("A", "B", * This example creates a `HeteroInfo` object with three node types ("n1", "n2",
* "C") and three edge types ("X", "Y", "Z"). The `node_offset` tensor * "n3") and three edge types (("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2",
* represents the offset array of node type, the given array indicates that node * "e3", "n3")). The `node_type_offset` tensor represents the offset array of
* [0, 2) has type "A", [2, 4) has type "B", and [4, 6) has type "C". And the * node type, the given array indicates that node [0, 2) has type "n1", [2, 4)
* `edge_types` tensor represents the type id of each edge. * has type "n2", and [4, 6) has type "n3". And the `type_per_edge` tensor
* represents the type id of each edge, which is the index in the `etypes`
* tensor.
*/ */
struct HeteroInfo { struct HeteroInfo {
/** /**
* @brief Constructs a new `HeteroInfo` object. * @brief Constructs a new `HeteroInfo` object.
* @param ntypes List of node types in the graph. * @param ntypes List of node types in the graph, where each node type is a
* @param etypes List of edge types in the graph. * string.
* @param etypes List of edge types in the graph, where each edge type is a
* string triplet `(str, str, str)`.
* @param node_type_offset Offset array of node type. It is assumed that nodes * @param node_type_offset Offset array of node type. It is assumed that nodes
* of same type have consecutive ids. * of same type have consecutive ids.
* @param type_per_edge Type id of each edge, where type id is the * @param type_per_edge Type id of each edge, where type id is the
* corresponding index of `edge_types`. * corresponding index of `edge_types`.
*/ */
HeteroInfo( HeteroInfo(
const StringList& ntypes, const StringList& etypes, const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor& node_type_offset, torch::Tensor& type_per_edge) torch::Tensor& node_type_offset, torch::Tensor& type_per_edge)
: node_types(ntypes), : node_types(ntypes),
edge_types(etypes), edge_types(etypes),
...@@ -55,10 +63,10 @@ struct HeteroInfo { ...@@ -55,10 +63,10 @@ struct HeteroInfo {
HeteroInfo() = default; HeteroInfo() = default;
/** @brief List of node types in the graph.*/ /** @brief List of node types in the graph.*/
StringList node_types; NodeTypeList node_types;
/** @brief List of edge types in the graph. */ /** @brief List of edge types in the graph. */
StringList edge_types; EdgeTypeList edge_types;
/** /**
* @brief Offset array of node type. The length of it is equal to number of * @brief Offset array of node type. The length of it is equal to number of
...@@ -138,7 +146,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -138,7 +146,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*/ */
static c10::intrusive_ptr<CSCSamplingGraph> FromCSCWithHeteroInfo( static c10::intrusive_ptr<CSCSamplingGraph> FromCSCWithHeteroInfo(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices, int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
const StringList& ntypes, const StringList& etypes, const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor node_type_offset, torch::Tensor type_per_edge); torch::Tensor node_type_offset, torch::Tensor type_per_edge);
/** @brief Get the number of nodes. */ /** @brief Get the number of nodes. */
...@@ -162,10 +170,10 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -162,10 +170,10 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
} }
/** @brief Get the list of node types for a heterogeneous graph. */ /** @brief Get the list of node types for a heterogeneous graph. */
inline StringList& NodeTypes() const { return hetero_info_->node_types; } inline NodeTypeList& NodeTypes() const { return hetero_info_->node_types; }
/** @brief Get the list of edge types for a heterogeneous graph. */ /** @brief Get the list of edge types for a heterogeneous graph. */
inline const StringList& EdgeTypes() const { inline const EdgeTypeList& EdgeTypes() const {
return hetero_info_->edge_types; return hetero_info_->edge_types;
} }
......
...@@ -56,7 +56,7 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC( ...@@ -56,7 +56,7 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo( c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices, int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
const StringList& ntypes, const StringList& etypes, const NodeTypeList& ntypes, const EdgeTypeList& etypes,
torch::Tensor node_type_offset, torch::Tensor type_per_edge) { torch::Tensor node_type_offset, torch::Tensor type_per_edge) {
TORCH_CHECK(node_type_offset.size(0) > 0); TORCH_CHECK(node_type_offset.size(0) > 0);
TORCH_CHECK(node_type_offset.dim() == 1); TORCH_CHECK(node_type_offset.dim() == 1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment