Unverified Commit 3cdc37cc authored by Rhett Ying, committed by GitHub

[GraphBolt] add ntype/etype_to_id into graph and save/load() (#6687)

parent 93b39729
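In short, this change teaches FusedCSCSamplingGraph to carry node-type and edge-type name-to-ID maps and to persist them in Save()/Load(). The sketch below illustrates the resulting Python-side usage; it assumes the dgl.graphbolt names touched in this diff (GraphMetadata, from_fused_csc), and the toy graph data is illustrative only.

```python
# Minimal sketch, assuming the dgl.graphbolt API as it appears in this PR
# (GraphMetadata, from_fused_csc); the toy graph below is illustrative only.
import torch
import dgl.graphbolt as gb

# Type-name -> type-ID maps now travel with the graph itself.
node_type_to_id = {"user": 0, "item": 1}
edge_type_to_id = {"user:click:item": 0}
metadata = gb.GraphMetadata(node_type_to_id, edge_type_to_id)

# 2 "user" nodes (IDs 0-1), 3 "item" nodes (IDs 2-4), three typed edges.
indptr = torch.tensor([0, 0, 0, 1, 2, 3])
indices = torch.tensor([0, 1, 1])
node_type_offset = torch.tensor([0, 2, 5])
type_per_edge = torch.zeros(3, dtype=torch.int64)

g = gb.from_fused_csc(
    indptr,
    indices,
    node_type_offset=node_type_offset,
    type_per_edge=type_per_edge,
    metadata=metadata,
)
# Both maps are stored on the underlying C++ graph and are written out and
# read back by its Save()/Load() methods.
```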
......@@ -48,6 +48,8 @@ struct SamplerArgs<SamplerType::LABOR> {
*/
class FusedCSCSamplingGraph : public torch::CustomClassHolder {
public:
using NodeTypeToIDMap = torch::Dict<std::string, int64_t>;
using EdgeTypeToIDMap = torch::Dict<std::string, int64_t>;
using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
/** @brief Default constructor. */
FusedCSCSamplingGraph() = default;
......@@ -60,11 +62,19 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* @param type_per_edge A tensor representing the type of each edge, if
* present.
* @param node_type_to_id A dictionary mapping node type names to type IDs, if
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
*/
FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes);
/**
......@@ -75,6 +85,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* @param type_per_edge A tensor representing the type of each edge, if
* present.
* @param node_type_to_id A dictionary mapping node type names to type IDs, if
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
* @return FusedCSCSamplingGraph
*/
......@@ -82,6 +97,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes);
/** @brief Get the number of nodes. */
......@@ -106,6 +123,22 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
return type_per_edge_;
}
/**
* @brief Get the node type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping node type names to type IDs.
*/
inline const torch::optional<NodeTypeToIDMap> NodeTypeToID() const {
return node_type_to_id_;
}
/**
* @brief Get the edge type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping edge type names to type IDs.
*/
inline const torch::optional<EdgeTypeToIDMap> EdgeTypeToID() const {
return edge_type_to_id_;
}
/** @brief Get the edge attributes dictionary. */
inline const torch::optional<EdgeAttrMap> EdgeAttributes() const {
return edge_attributes_;
......@@ -129,6 +162,24 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
type_per_edge_ = type_per_edge;
}
/**
* @brief Set the node type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping node type names to type IDs.
*/
inline void SetNodeTypeToID(
const torch::optional<NodeTypeToIDMap>& node_type_to_id) {
node_type_to_id_ = node_type_to_id;
}
/**
* @brief Set the edge type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping edge type names to type IDs.
*/
inline void SetEdgeTypeToID(
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id) {
edge_type_to_id_ = edge_type_to_id;
}
/** @brief Set the edge attributes dictionary. */
inline void SetEdgeAttributes(
const torch::optional<EdgeAttrMap>& edge_attributes) {
......@@ -302,6 +353,20 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
*/
torch::optional<torch::Tensor> type_per_edge_;
/**
* @brief A dictionary mapping node type names to type IDs. Its size equals
* the number of node types; each key is a node type name and each value is
* the corresponding type ID.
*/
torch::optional<NodeTypeToIDMap> node_type_to_id_;
/**
* @brief A dictionary mapping edge type names to type IDs. Its size equals
* the number of edge types; each key is an edge type name and each value is
* the corresponding type ID.
*/
torch::optional<EdgeTypeToIDMap> edge_type_to_id_;
/**
* @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
......
......@@ -28,11 +28,15 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes)
: indptr_(indptr),
indices_(indices),
node_type_offset_(node_type_offset),
type_per_edge_(type_per_edge),
node_type_to_id_(node_type_to_id),
edge_type_to_id_(edge_type_to_id),
edge_attributes_(edge_attributes) {
TORCH_CHECK(indptr.dim() == 1);
TORCH_CHECK(indices.dim() == 1);
......@@ -43,14 +47,21 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes) {
if (node_type_offset.has_value()) {
auto& offset = node_type_offset.value();
TORCH_CHECK(offset.dim() == 1);
TORCH_CHECK(node_type_to_id.has_value());
TORCH_CHECK(
offset.size(0) ==
static_cast<int64_t>(node_type_to_id.value().size() + 1));
}
if (type_per_edge.has_value()) {
TORCH_CHECK(type_per_edge.value().dim() == 1);
TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
TORCH_CHECK(edge_type_to_id.has_value());
}
if (edge_attributes.has_value()) {
for (const auto& pair : edge_attributes.value()) {
......@@ -58,7 +69,8 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
}
}
return c10::make_intrusive<FusedCSCSamplingGraph>(
indptr, indices, node_type_offset, type_per_edge, edge_attributes);
indptr, indices, node_type_offset, type_per_edge, node_type_to_id,
edge_type_to_id, edge_attributes);
}
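The checks above tie the optional tensors to the new maps: node_type_offset must hold exactly one more entry than there are node types, and type_per_edge is only accepted together with an edge_type_to_id map. A minimal sketch of those invariants, with illustrative values:

```python
# Hedged sketch of the invariants FromCSC now enforces; values illustrative.
import torch

node_type_to_id = {"user": 0, "item": 1}
edge_type_to_id = {"user:click:item": 0}

# With |ntypes| = 2 node types, node_type_offset needs |ntypes| + 1 entries,
# so nodes of type i occupy IDs [offset[i], offset[i + 1]).
node_type_offset = torch.tensor([0, 2, 5])
assert node_type_offset.size(0) == len(node_type_to_id) + 1

# type_per_edge must have one entry per edge and requires edge_type_to_id.
indices = torch.tensor([0, 1, 1])
type_per_edge = torch.zeros(3, dtype=torch.int64)
assert type_per_edge.size(0) == indices.size(0)
assert edge_type_to_id is not None
```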
void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
......@@ -84,6 +96,34 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
.toTensor();
}
if (read_from_archive(archive, "FusedCSCSamplingGraph/has_node_type_to_id")
.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "FusedCSCSamplingGraph/node_type_to_id")
.toGenericDict();
NodeTypeToIDMap node_type_to_id;
for (const auto& pair : generic_dict) {
std::string key = pair.key().toStringRef();
int64_t value = pair.value().toInt();
node_type_to_id.insert(std::move(key), value);
}
node_type_to_id_ = std::move(node_type_to_id);
}
if (read_from_archive(archive, "FusedCSCSamplingGraph/has_edge_type_to_id")
.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "FusedCSCSamplingGraph/edge_type_to_id")
.toGenericDict();
EdgeTypeToIDMap edge_type_to_id;
for (const auto& pair : generic_dict) {
std::string key = pair.key().toStringRef();
int64_t value = pair.value().toInt();
edge_type_to_id.insert(std::move(key), value);
}
edge_type_to_id_ = std::move(edge_type_to_id);
}
// Optional edge attributes.
torch::IValue has_edge_attributes;
if (archive.try_read(
......@@ -123,6 +163,20 @@ void FusedCSCSamplingGraph::Save(
archive.write(
"FusedCSCSamplingGraph/type_per_edge", type_per_edge_.value());
}
archive.write(
"FusedCSCSamplingGraph/has_node_type_to_id",
node_type_to_id_.has_value());
if (node_type_to_id_) {
archive.write(
"FusedCSCSamplingGraph/node_type_to_id", node_type_to_id_.value());
}
archive.write(
"FusedCSCSamplingGraph/has_edge_type_to_id",
edge_type_to_id_.has_value());
if (edge_type_to_id_) {
archive.write(
"FusedCSCSamplingGraph/edge_type_to_id", edge_type_to_id_.value());
}
archive.write(
"FusedCSCSamplingGraph/has_edge_attributes",
edge_attributes_.has_value());
......@@ -505,7 +559,7 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
auto edge_attributes = helper.ReadTorchTensorDict();
auto graph = c10::make_intrusive<FusedCSCSamplingGraph>(
indptr.value(), indices.value(), node_type_offset, type_per_edge,
edge_attributes);
torch::nullopt, torch::nullopt, edge_attributes);
auto shared_memory = helper.ReleaseSharedMemory();
graph->HoldSharedMemoryObject(
std::move(shared_memory.first), std::move(shared_memory.second));
......
......@@ -1254,7 +1254,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
)
# Construct GraphMetadata.
_, _, ntypes, etypes = load_partition_book(part_config, part_id)
metadata = graphbolt.GraphMetadata(ntypes, etypes)
node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)}
edge_type_to_id = {
_etype_tuple_to_str(etype): etid
for etid, etype in enumerate(etypes)
}
metadata = graphbolt.GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC indptr and indices.
indptr, indices, _ = graph.adj().csc()
# Initialize type per edge.
......@@ -1263,7 +1268,11 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
# Sanity check.
assert len(type_per_edge) == graph.num_edges()
csc_graph = graphbolt.from_fused_csc(
indptr, indices, None, type_per_edge, metadata=metadata
indptr,
indices,
node_type_offset=None,
type_per_edge=type_per_edge,
metadata=metadata,
)
orig_graph_path = os.path.join(
os.path.dirname(part_config),
......
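For context, the dict construction above flattens DGL's canonical edge-type tuples into the "src:rel:dst" strings that GraphMetadata keys on. A self-contained sketch; the helper here is a simplified stand-in for DGL's _etype_tuple_to_str:

```python
# Hedged sketch of the dict construction above; _etype_tuple_to_str here is a
# simplified stand-in for the DGL helper of the same name.
ntypes = ["user", "item"]
etypes = [("user", "click", "item")]

def _etype_tuple_to_str(etype):
    # Canonical ("src", "rel", "dst") tuple -> "src:rel:dst" string.
    return ":".join(etype)

node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)}
edge_type_to_id = {
    _etype_tuple_to_str(etype): etid for etid, etype in enumerate(etypes)
}
# node_type_to_id == {"user": 0, "item": 1}
# edge_type_to_id == {"user:click:item": 0}
```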
......@@ -456,7 +456,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
node_edge_type[dst_ntype_id].append((etype, etype_id))
# construct subgraphs
for (i, seed) in enumerate(column):
for i, seed in enumerate(column):
l = indptr[i].item()
r = indptr[i + 1].item()
node_type = (
......@@ -465,7 +465,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
).item()
- 1
)
for (etype, etype_id) in node_edge_type[node_type]:
for etype, etype_id in node_edge_type[node_type]:
src_ntype, _, _ = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
num_edges = torch.searchsorted(
......@@ -925,12 +925,16 @@ def from_fused_csc(
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
0
), "node_type_offset length should be |ntypes| + 1."
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
node_type_to_id,
edge_type_to_id,
edge_attributes,
),
metadata,
......@@ -1046,7 +1050,11 @@ def from_dglgraph(
# Obtain CSC matrix.
indptr, indices, edge_ids = homo_g.adj_tensors("csc")
ntype_count.insert(0, 0)
node_type_offset = torch.cumsum(torch.LongTensor(ntype_count), 0)
node_type_offset = (
None
if is_homogeneous
else torch.cumsum(torch.LongTensor(ntype_count), 0)
)
# Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
......@@ -1056,12 +1064,16 @@ def from_dglgraph(
# Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc(
indptr,
indices,
node_type_offset,
type_per_edge,
node_type_to_id,
edge_type_to_id,
edge_attributes,
),
metadata,
......
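With this change, from_dglgraph only materializes node_type_offset and the type maps for heterogeneous input; a homogeneous graph gets None for all of them (the updated homogeneous test in this diff asserts node_type_offset is None). A rough sketch of that branch, using illustrative stand-ins rather than the real from_dglgraph locals:

```python
# Hedged sketch of the branch above; inputs are illustrative stand-ins for
# the real from_dglgraph locals, not its actual signature.
import torch

def optional_hetero_fields(ntype_count, metadata, is_homogeneous):
    # Homogeneous graphs carry neither an offset tensor nor type maps.
    node_type_offset = (
        None
        if is_homogeneous
        else torch.cumsum(torch.LongTensor([0] + ntype_count), 0)
    )
    node_type_to_id = metadata.node_type_to_id if metadata else None
    edge_type_to_id = metadata.edge_type_to_id if metadata else None
    return node_type_offset, node_type_to_id, edge_type_to_id

print(optional_hetero_fields([1000], None, True))      # (None, None, None)
print(optional_hetero_fields([2, 3], None, False)[0])  # tensor([0, 2, 5])
```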
......@@ -16,6 +16,7 @@ from dgl.distributed import (
partition_graph,
)
from dgl.distributed.graph_partition_book import (
_etype_str_to_tuple,
_etype_tuple_to_str,
DEFAULT_ETYPE,
DEFAULT_NTYPE,
......@@ -707,7 +708,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
for node_type, type_id in new_g.metadata.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items():
assert g.get_etype_id(edge_type) == type_id
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
@pytest.mark.parametrize("part_method", ["metis", "random"])
......@@ -738,7 +739,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
for node_type, type_id in new_g.metadata.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items():
assert g.get_etype_id(edge_type) == type_id
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert new_g.node_type_offset is None
assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)
......
......@@ -1354,7 +1354,7 @@ def test_from_dglgraph_homogeneous():
assert gb_g.total_num_nodes == dgl_g.num_nodes()
assert gb_g.total_num_edges == dgl_g.num_edges()
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 1000]))
assert gb_g.node_type_offset is None
assert gb_g.type_per_edge is None
assert gb_g.metadata is None
......
......@@ -1999,8 +1999,8 @@ def test_BuiltinDataset():
"""Test BuiltinDataset."""
with tempfile.TemporaryDirectory() as test_dir:
# Case 1: download from DGL S3 storage.
dataset_name = "test-only"
# Add test-only dataset to the builtin dataset list for testing only.
dataset_name = "test-dataset-231204"
# Add dataset to the builtin dataset list for testing only.
gb.BuiltinDataset._all_datasets.append(dataset_name)
dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()
assert dataset.graph is not None
......