Unverified Commit 3cdc37cc authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add ntype/etype_to_id into graph and save/load() (#6687)

parent 93b39729
...@@ -48,6 +48,8 @@ struct SamplerArgs<SamplerType::LABOR> { ...@@ -48,6 +48,8 @@ struct SamplerArgs<SamplerType::LABOR> {
*/ */
class FusedCSCSamplingGraph : public torch::CustomClassHolder { class FusedCSCSamplingGraph : public torch::CustomClassHolder {
public: public:
using NodeTypeToIDMap = torch::Dict<std::string, int64_t>;
using EdgeTypeToIDMap = torch::Dict<std::string, int64_t>;
using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>; using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
/** @brief Default constructor. */ /** @brief Default constructor. */
FusedCSCSamplingGraph() = default; FusedCSCSamplingGraph() = default;
...@@ -60,11 +62,19 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -60,11 +62,19 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present. * present.
* @param type_per_edge A tensor representing the type of each edge, if * @param type_per_edge A tensor representing the type of each edge, if
* present. * present.
* @param node_type_to_id A dictionary mapping node type names to type IDs, if
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
*/ */
FusedCSCSamplingGraph( FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes); const torch::optional<EdgeAttrMap>& edge_attributes);
/** /**
...@@ -75,6 +85,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -75,6 +85,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present. * present.
* @param type_per_edge A tensor representing the type of each edge, if * @param type_per_edge A tensor representing the type of each edge, if
* present. * present.
* @param node_type_to_id A dictionary mapping node type names to type IDs, if
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param edge_attributes A dictionary of edge attributes, if present.
* *
* @return FusedCSCSamplingGraph * @return FusedCSCSamplingGraph
*/ */
...@@ -82,6 +97,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -82,6 +97,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes); const torch::optional<EdgeAttrMap>& edge_attributes);
/** @brief Get the number of nodes. */ /** @brief Get the number of nodes. */
...@@ -106,6 +123,22 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -106,6 +123,22 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
return type_per_edge_; return type_per_edge_;
} }
/**
* @brief Get the node type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping node type names to type IDs.
*/
inline const torch::optional<NodeTypeToIDMap> NodeTypeToID() const {
return node_type_to_id_;
}
/**
* @brief Get the edge type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping edge type names to type IDs.
*/
inline const torch::optional<EdgeTypeToIDMap> EdgeTypeToID() const {
return edge_type_to_id_;
}
/** @brief Get the edge attributes dictionary. */ /** @brief Get the edge attributes dictionary. */
inline const torch::optional<EdgeAttrMap> EdgeAttributes() const { inline const torch::optional<EdgeAttrMap> EdgeAttributes() const {
return edge_attributes_; return edge_attributes_;
...@@ -129,6 +162,24 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -129,6 +162,24 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
type_per_edge_ = type_per_edge; type_per_edge_ = type_per_edge;
} }
/**
* @brief Set the node type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping node type names to type IDs.
*/
inline void SetNodeTypeToID(
const torch::optional<NodeTypeToIDMap>& node_type_to_id) {
node_type_to_id_ = node_type_to_id;
}
/**
* @brief Set the edge type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping edge type names to type IDs.
*/
inline void SetEdgeTypeToID(
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id) {
edge_type_to_id_ = edge_type_to_id;
}
/** @brief Set the edge attributes dictionary. */ /** @brief Set the edge attributes dictionary. */
inline void SetEdgeAttributes( inline void SetEdgeAttributes(
const torch::optional<EdgeAttrMap>& edge_attributes) { const torch::optional<EdgeAttrMap>& edge_attributes) {
...@@ -302,6 +353,20 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -302,6 +353,20 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
*/ */
torch::optional<torch::Tensor> type_per_edge_; torch::optional<torch::Tensor> type_per_edge_;
/**
* @brief A dictionary mapping node type names to type IDs. The length of it
* is equal to the number of node types. The key is the node type name, and
* the value is the corresponding type ID.
*/
torch::optional<NodeTypeToIDMap> node_type_to_id_;
/**
* @brief A dictionary mapping edge type names to type IDs. The length of it
* is equal to the number of edge types. The key is the edge type name, and
* the value is the corresponding type ID.
*/
torch::optional<EdgeTypeToIDMap> edge_type_to_id_;
/** /**
* @brief A dictionary of edge attributes. Each key represents the attribute's * @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value. * name, while the corresponding value holds the attribute's specific value.
......
...@@ -28,11 +28,15 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph( ...@@ -28,11 +28,15 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes) const torch::optional<EdgeAttrMap>& edge_attributes)
: indptr_(indptr), : indptr_(indptr),
indices_(indices), indices_(indices),
node_type_offset_(node_type_offset), node_type_offset_(node_type_offset),
type_per_edge_(type_per_edge), type_per_edge_(type_per_edge),
node_type_to_id_(node_type_to_id),
edge_type_to_id_(edge_type_to_id),
edge_attributes_(edge_attributes) { edge_attributes_(edge_attributes) {
TORCH_CHECK(indptr.dim() == 1); TORCH_CHECK(indptr.dim() == 1);
TORCH_CHECK(indices.dim() == 1); TORCH_CHECK(indices.dim() == 1);
...@@ -43,14 +47,21 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC( ...@@ -43,14 +47,21 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<NodeTypeToIDMap>& node_type_to_id,
const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
const torch::optional<EdgeAttrMap>& edge_attributes) { const torch::optional<EdgeAttrMap>& edge_attributes) {
if (node_type_offset.has_value()) { if (node_type_offset.has_value()) {
auto& offset = node_type_offset.value(); auto& offset = node_type_offset.value();
TORCH_CHECK(offset.dim() == 1); TORCH_CHECK(offset.dim() == 1);
TORCH_CHECK(node_type_to_id.has_value());
TORCH_CHECK(
offset.size(0) ==
static_cast<int64_t>(node_type_to_id.value().size() + 1));
} }
if (type_per_edge.has_value()) { if (type_per_edge.has_value()) {
TORCH_CHECK(type_per_edge.value().dim() == 1); TORCH_CHECK(type_per_edge.value().dim() == 1);
TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0)); TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
TORCH_CHECK(edge_type_to_id.has_value());
} }
if (edge_attributes.has_value()) { if (edge_attributes.has_value()) {
for (const auto& pair : edge_attributes.value()) { for (const auto& pair : edge_attributes.value()) {
...@@ -58,7 +69,8 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC( ...@@ -58,7 +69,8 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
} }
} }
return c10::make_intrusive<FusedCSCSamplingGraph>( return c10::make_intrusive<FusedCSCSamplingGraph>(
indptr, indices, node_type_offset, type_per_edge, edge_attributes); indptr, indices, node_type_offset, type_per_edge, node_type_to_id,
edge_type_to_id, edge_attributes);
} }
void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
...@@ -84,6 +96,34 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { ...@@ -84,6 +96,34 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
.toTensor(); .toTensor();
} }
if (read_from_archive(archive, "FusedCSCSamplingGraph/has_node_type_to_id")
.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "FusedCSCSamplingGraph/node_type_to_id")
.toGenericDict();
NodeTypeToIDMap node_type_to_id;
for (const auto& pair : generic_dict) {
std::string key = pair.key().toStringRef();
int64_t value = pair.value().toInt();
node_type_to_id.insert(std::move(key), value);
}
node_type_to_id_ = std::move(node_type_to_id);
}
if (read_from_archive(archive, "FusedCSCSamplingGraph/has_edge_type_to_id")
.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "FusedCSCSamplingGraph/edge_type_to_id")
.toGenericDict();
EdgeTypeToIDMap edge_type_to_id;
for (const auto& pair : generic_dict) {
std::string key = pair.key().toStringRef();
int64_t value = pair.value().toInt();
edge_type_to_id.insert(std::move(key), value);
}
edge_type_to_id_ = std::move(edge_type_to_id);
}
// Optional edge attributes. // Optional edge attributes.
torch::IValue has_edge_attributes; torch::IValue has_edge_attributes;
if (archive.try_read( if (archive.try_read(
...@@ -123,6 +163,20 @@ void FusedCSCSamplingGraph::Save( ...@@ -123,6 +163,20 @@ void FusedCSCSamplingGraph::Save(
archive.write( archive.write(
"FusedCSCSamplingGraph/type_per_edge", type_per_edge_.value()); "FusedCSCSamplingGraph/type_per_edge", type_per_edge_.value());
} }
archive.write(
"FusedCSCSamplingGraph/has_node_type_to_id",
node_type_to_id_.has_value());
if (node_type_to_id_) {
archive.write(
"FusedCSCSamplingGraph/node_type_to_id", node_type_to_id_.value());
}
archive.write(
"FusedCSCSamplingGraph/has_edge_type_to_id",
edge_type_to_id_.has_value());
if (edge_type_to_id_) {
archive.write(
"FusedCSCSamplingGraph/edge_type_to_id", edge_type_to_id_.value());
}
archive.write( archive.write(
"FusedCSCSamplingGraph/has_edge_attributes", "FusedCSCSamplingGraph/has_edge_attributes",
edge_attributes_.has_value()); edge_attributes_.has_value());
...@@ -505,7 +559,7 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) { ...@@ -505,7 +559,7 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
auto edge_attributes = helper.ReadTorchTensorDict(); auto edge_attributes = helper.ReadTorchTensorDict();
auto graph = c10::make_intrusive<FusedCSCSamplingGraph>( auto graph = c10::make_intrusive<FusedCSCSamplingGraph>(
indptr.value(), indices.value(), node_type_offset, type_per_edge, indptr.value(), indices.value(), node_type_offset, type_per_edge,
edge_attributes); torch::nullopt, torch::nullopt, edge_attributes);
auto shared_memory = helper.ReleaseSharedMemory(); auto shared_memory = helper.ReleaseSharedMemory();
graph->HoldSharedMemoryObject( graph->HoldSharedMemoryObject(
std::move(shared_memory.first), std::move(shared_memory.second)); std::move(shared_memory.first), std::move(shared_memory.second));
......
...@@ -1254,7 +1254,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1254,7 +1254,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
) )
# Construct GraphMetadata. # Construct GraphMetadata.
_, _, ntypes, etypes = load_partition_book(part_config, part_id) _, _, ntypes, etypes = load_partition_book(part_config, part_id)
metadata = graphbolt.GraphMetadata(ntypes, etypes) node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)}
edge_type_to_id = {
_etype_tuple_to_str(etype): etid
for etid, etype in enumerate(etypes)
}
metadata = graphbolt.GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC indtpr and indices. # Obtain CSC indtpr and indices.
indptr, indices, _ = graph.adj().csc() indptr, indices, _ = graph.adj().csc()
# Initalize type per edge. # Initalize type per edge.
...@@ -1263,7 +1268,11 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1263,7 +1268,11 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
# Sanity check. # Sanity check.
assert len(type_per_edge) == graph.num_edges() assert len(type_per_edge) == graph.num_edges()
csc_graph = graphbolt.from_fused_csc( csc_graph = graphbolt.from_fused_csc(
indptr, indices, None, type_per_edge, metadata=metadata indptr,
indices,
node_type_offset=None,
type_per_edge=type_per_edge,
metadata=metadata,
) )
orig_graph_path = os.path.join( orig_graph_path = os.path.join(
os.path.dirname(part_config), os.path.dirname(part_config),
......
...@@ -456,7 +456,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -456,7 +456,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype] dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
node_edge_type[dst_ntype_id].append((etype, etype_id)) node_edge_type[dst_ntype_id].append((etype, etype_id))
# construct subgraphs # construct subgraphs
for (i, seed) in enumerate(column): for i, seed in enumerate(column):
l = indptr[i].item() l = indptr[i].item()
r = indptr[i + 1].item() r = indptr[i + 1].item()
node_type = ( node_type = (
...@@ -465,7 +465,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -465,7 +465,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
).item() ).item()
- 1 - 1
) )
for (etype, etype_id) in node_edge_type[node_type]: for etype, etype_id in node_edge_type[node_type]:
src_ntype, _, _ = etype_str_to_tuple(etype) src_ntype, _, _ = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype] src_ntype_id = self.metadata.node_type_to_id[src_ntype]
num_edges = torch.searchsorted( num_edges = torch.searchsorted(
...@@ -925,12 +925,16 @@ def from_fused_csc( ...@@ -925,12 +925,16 @@ def from_fused_csc(
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size( assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
0 0
), "node_type_offset length should be |ntypes| + 1." ), "node_type_offset length should be |ntypes| + 1."
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc( torch.ops.graphbolt.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
node_type_to_id,
edge_type_to_id,
edge_attributes, edge_attributes,
), ),
metadata, metadata,
...@@ -1046,7 +1050,11 @@ def from_dglgraph( ...@@ -1046,7 +1050,11 @@ def from_dglgraph(
# Obtain CSC matrix. # Obtain CSC matrix.
indptr, indices, edge_ids = homo_g.adj_tensors("csc") indptr, indices, edge_ids = homo_g.adj_tensors("csc")
ntype_count.insert(0, 0) ntype_count.insert(0, 0)
node_type_offset = torch.cumsum(torch.LongTensor(ntype_count), 0) node_type_offset = (
None
if is_homogeneous
else torch.cumsum(torch.LongTensor(ntype_count), 0)
)
# Assign edge type according to the order of CSC matrix. # Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids] type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
...@@ -1056,12 +1064,16 @@ def from_dglgraph( ...@@ -1056,12 +1064,16 @@ def from_dglgraph(
# Assign edge attributes according to the original eids mapping. # Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids] edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc( torch.ops.graphbolt.from_fused_csc(
indptr, indptr,
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
node_type_to_id,
edge_type_to_id,
edge_attributes, edge_attributes,
), ),
metadata, metadata,
......
...@@ -16,6 +16,7 @@ from dgl.distributed import ( ...@@ -16,6 +16,7 @@ from dgl.distributed import (
partition_graph, partition_graph,
) )
from dgl.distributed.graph_partition_book import ( from dgl.distributed.graph_partition_book import (
_etype_str_to_tuple,
_etype_tuple_to_str, _etype_tuple_to_str,
DEFAULT_ETYPE, DEFAULT_ETYPE,
DEFAULT_NTYPE, DEFAULT_NTYPE,
...@@ -707,7 +708,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo( ...@@ -707,7 +708,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
for node_type, type_id in new_g.metadata.node_type_to_id.items(): for node_type, type_id in new_g.metadata.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items(): for edge_type, type_id in new_g.metadata.edge_type_to_id.items():
assert g.get_etype_id(edge_type) == type_id assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
@pytest.mark.parametrize("part_method", ["metis", "random"]) @pytest.mark.parametrize("part_method", ["metis", "random"])
...@@ -738,7 +739,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero( ...@@ -738,7 +739,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
for node_type, type_id in new_g.metadata.node_type_to_id.items(): for node_type, type_id in new_g.metadata.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items(): for edge_type, type_id in new_g.metadata.edge_type_to_id.items():
assert g.get_etype_id(edge_type) == type_id assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert new_g.node_type_offset is None assert new_g.node_type_offset is None
assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge) assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)
......
...@@ -1354,7 +1354,7 @@ def test_from_dglgraph_homogeneous(): ...@@ -1354,7 +1354,7 @@ def test_from_dglgraph_homogeneous():
assert gb_g.total_num_nodes == dgl_g.num_nodes() assert gb_g.total_num_nodes == dgl_g.num_nodes()
assert gb_g.total_num_edges == dgl_g.num_edges() assert gb_g.total_num_edges == dgl_g.num_edges()
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 1000])) assert gb_g.node_type_offset is None
assert gb_g.type_per_edge is None assert gb_g.type_per_edge is None
assert gb_g.metadata is None assert gb_g.metadata is None
......
...@@ -1999,8 +1999,8 @@ def test_BuiltinDataset(): ...@@ -1999,8 +1999,8 @@ def test_BuiltinDataset():
"""Test BuiltinDataset.""" """Test BuiltinDataset."""
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
# Case 1: download from DGL S3 storage. # Case 1: download from DGL S3 storage.
dataset_name = "test-only" dataset_name = "test-dataset-231204"
# Add test-only dataset to the builtin dataset list for testing only. # Add dataset to the builtin dataset list for testing only.
gb.BuiltinDataset._all_datasets.append(dataset_name) gb.BuiltinDataset._all_datasets.append(dataset_name)
dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load() dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()
assert dataset.graph is not None assert dataset.graph is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment