"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "f0ec93d0163f7017b1849771b7de605d6d82e520"
Unverified Commit c2134442 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] use torch.load/save instead of load/save_fused_xxx() (#6707)

parent 0348ad3d
...@@ -21,11 +21,16 @@ namespace torch { ...@@ -21,11 +21,16 @@ namespace torch {
/** /**
* @brief Overload input stream operator for FusedCSCSamplingGraph * @brief Overload input stream operator for FusedCSCSamplingGraph
* deserialization. * deserialization. This enables `torch::load()` for FusedCSCSamplingGraph.
*
* @param archive Input stream for deserializing. * @param archive Input stream for deserializing.
* @param graph FusedCSCSamplingGraph. * @param graph FusedCSCSamplingGraph.
* *
* @return archive * @return archive
*
* @code
* auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
* torch::load(*graph, filename);
*/ */
inline serialize::InputArchive& operator>>( inline serialize::InputArchive& operator>>(
serialize::InputArchive& archive, serialize::InputArchive& archive,
...@@ -33,11 +38,15 @@ inline serialize::InputArchive& operator>>( ...@@ -33,11 +38,15 @@ inline serialize::InputArchive& operator>>(
/** /**
* @brief Overload output stream operator for FusedCSCSamplingGraph * @brief Overload output stream operator for FusedCSCSamplingGraph
* serialization. * serialization. This enables `torch::save()` for FusedCSCSamplingGraph.
* @param archive Output stream for serializing. * @param archive Output stream for serializing.
* @param graph FusedCSCSamplingGraph. * @param graph FusedCSCSamplingGraph.
* *
* @return archive * @return archive
*
* @code
* auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
* torch::save(*graph, filename);
*/ */
inline serialize::OutputArchive& operator<<( inline serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive, serialize::OutputArchive& archive,
...@@ -47,25 +56,6 @@ inline serialize::OutputArchive& operator<<( ...@@ -47,25 +56,6 @@ inline serialize::OutputArchive& operator<<(
namespace graphbolt { namespace graphbolt {
/**
* @brief Load FusedCSCSamplingGraph from file.
* @param filename File name to read.
*
* @return FusedCSCSamplingGraph.
*/
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
const std::string& filename);
/**
* @brief Save FusedCSCSamplingGraph to file.
* @param graph FusedCSCSamplingGraph to save.
* @param filename File name to save.
*
*/
void SaveFusedCSCSamplingGraph(
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
const std::string& filename);
/** /**
* @brief Read data from archive. * @brief Read data from archive.
* @param archive Input archive. * @param archive Input archive.
......
...@@ -66,8 +66,6 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -66,8 +66,6 @@ TORCH_LIBRARY(graphbolt, m) {
return g; return g;
}); });
m.def("from_fused_csc", &FusedCSCSamplingGraph::FromCSC); m.def("from_fused_csc", &FusedCSCSamplingGraph::FromCSC);
m.def("load_fused_csc_sampling_graph", &LoadFusedCSCSamplingGraph);
m.def("save_fused_csc_sampling_graph", &SaveFusedCSCSamplingGraph);
m.def( m.def(
"load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory); "load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
m.def("unique_and_compact", &UniqueAndCompact); m.def("unique_and_compact", &UniqueAndCompact);
......
...@@ -27,19 +27,6 @@ serialize::OutputArchive& operator<<( ...@@ -27,19 +27,6 @@ serialize::OutputArchive& operator<<(
namespace graphbolt { namespace graphbolt {
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
const std::string& filename) {
auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
torch::load(*graph, filename);
return graph;
}
void SaveFusedCSCSamplingGraph(
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
const std::string& filename) {
torch::save(*graph, filename);
}
torch::IValue read_from_archive( torch::IValue read_from_archive(
torch::serialize::InputArchive& archive, const std::string& key) { torch::serialize::InputArchive& archive, const std::string& key) {
torch::IValue data; torch::IValue data;
......
...@@ -7,6 +7,8 @@ import time ...@@ -7,6 +7,8 @@ import time
import numpy as np import numpy as np
import torch
from .. import backend as F from .. import backend as F
from ..base import DGLError, EID, ETYPE, NID, NTYPE from ..base import DGLError, EID, ETYPE, NID, NTYPE
from ..convert import to_homogeneous from ..convert import to_homogeneous
...@@ -1236,6 +1238,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1236,6 +1238,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
part_config : str part_config : str
The partition configuration JSON file. The partition configuration JSON file.
""" """
# As only this function requires GraphBolt for now, let's import here. # As only this function requires GraphBolt for now, let's import here.
from .. import graphbolt from .. import graphbolt
...@@ -1279,6 +1282,6 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1279,6 +1282,6 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
part_meta[f"part-{part_id}"]["part_graph"], part_meta[f"part-{part_id}"]["part_graph"],
) )
csc_graph_path = os.path.join( csc_graph_path = os.path.join(
os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.tar" os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.pt"
) )
graphbolt.save_fused_csc_sampling_graph(csc_graph, csc_graph_path) torch.save(csc_graph, csc_graph_path)
"""CSC format sampling graph.""" """CSC format sampling graph."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
import os
import tarfile
import tempfile
from collections import defaultdict from collections import defaultdict
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
...@@ -27,8 +24,6 @@ __all__ = [ ...@@ -27,8 +24,6 @@ __all__ = [
"FusedCSCSamplingGraph", "FusedCSCSamplingGraph",
"from_fused_csc", "from_fused_csc",
"load_from_shared_memory", "load_from_shared_memory",
"load_fused_csc_sampling_graph",
"save_fused_csc_sampling_graph",
"from_dglgraph", "from_dglgraph",
] ]
...@@ -99,11 +94,11 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -99,11 +94,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
return _csc_sampling_graph_str(self) return _csc_sampling_graph_str(self)
def __init__( def __init__(
self, c_csc_graph: torch.ScriptObject, metadata: Optional[GraphMetadata] self,
c_csc_graph: torch.ScriptObject,
): ):
super().__init__() super().__init__()
self._c_csc_graph = c_csc_graph self._c_csc_graph = c_csc_graph
self._metadata = metadata
@property @property
def total_num_nodes(self) -> int: def total_num_nodes(self) -> int:
...@@ -318,12 +313,16 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -318,12 +313,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
def metadata(self) -> Optional[GraphMetadata]: def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph. """Returns the metadata of the graph.
[TODO][Rui] This API needs to be updated.
Returns Returns
------- -------
GraphMetadata or None GraphMetadata or None
If present, returns the metadata of the graph. If present, returns the metadata of the graph.
""" """
return self._metadata if self.node_type_to_id is None or self.edge_type_to_id is None:
return None
return GraphMetadata(self.node_type_to_id, self.edge_type_to_id)
def in_subgraph( def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
...@@ -884,7 +883,6 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -884,7 +883,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
""" """
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
self._c_csc_graph.copy_to_shared_memory(shared_memory_name), self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
self._metadata,
) )
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
...@@ -975,13 +973,11 @@ def from_fused_csc( ...@@ -975,13 +973,11 @@ def from_fused_csc(
edge_type_to_id, edge_type_to_id,
edge_attributes, edge_attributes,
), ),
metadata,
) )
def load_from_shared_memory( def load_from_shared_memory(
shared_memory_name: str, shared_memory_name: str,
metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph: ) -> FusedCSCSamplingGraph:
"""Load a FusedCSCSamplingGraph object from shared memory. """Load a FusedCSCSamplingGraph object from shared memory.
...@@ -997,7 +993,6 @@ def load_from_shared_memory( ...@@ -997,7 +993,6 @@ def load_from_shared_memory(
""" """
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.load_from_shared_memory(shared_memory_name), torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),
metadata,
) )
...@@ -1033,38 +1028,6 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str: ...@@ -1033,38 +1028,6 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
return final_str return final_str
def load_fused_csc_sampling_graph(filename):
"""Load FusedCSCSamplingGraph from tar file."""
with tempfile.TemporaryDirectory() as temp_dir:
with tarfile.open(filename, "r") as archive:
archive.extractall(temp_dir)
graph_filename = os.path.join(temp_dir, "fused_csc_sampling_graph.pt")
metadata_filename = os.path.join(temp_dir, "metadata.pt")
return FusedCSCSamplingGraph(
torch.ops.graphbolt.load_fused_csc_sampling_graph(graph_filename),
torch.load(metadata_filename),
)
def save_fused_csc_sampling_graph(graph, filename):
"""Save FusedCSCSamplingGraph to tar file."""
with tempfile.TemporaryDirectory() as temp_dir:
graph_filename = os.path.join(temp_dir, "fused_csc_sampling_graph.pt")
torch.ops.graphbolt.save_fused_csc_sampling_graph(
graph._c_csc_graph, graph_filename
)
metadata_filename = os.path.join(temp_dir, "metadata.pt")
torch.save(graph.metadata, metadata_filename)
with tarfile.open(filename, "w") as archive:
archive.add(
graph_filename, arcname=os.path.basename(graph_filename)
)
archive.add(
metadata_filename, arcname=os.path.basename(metadata_filename)
)
print(f"FusedCSCSamplingGraph has been saved to {filename}.")
def from_dglgraph( def from_dglgraph(
g: DGLGraph, g: DGLGraph,
is_homogeneous: bool = False, is_homogeneous: bool = False,
...@@ -1114,5 +1077,4 @@ def from_dglgraph( ...@@ -1114,5 +1077,4 @@ def from_dglgraph(
edge_type_to_id, edge_type_to_id,
edge_attributes, edge_attributes,
), ),
metadata,
) )
...@@ -17,12 +17,7 @@ from ..dataset import Dataset, Task ...@@ -17,12 +17,7 @@ from ..dataset import Dataset, Task
from ..internal import copy_or_convert_data, read_data from ..internal import copy_or_convert_data, read_data
from ..itemset import ItemSet, ItemSetDict from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph from ..sampling_graph import SamplingGraph
from .fused_csc_sampling_graph import ( from .fused_csc_sampling_graph import from_dglgraph, FusedCSCSamplingGraph
from_dglgraph,
FusedCSCSamplingGraph,
load_fused_csc_sampling_graph,
save_fused_csc_sampling_graph,
)
from .ondisk_metadata import ( from .ondisk_metadata import (
OnDiskGraphTopology, OnDiskGraphTopology,
OnDiskMetaData, OnDiskMetaData,
...@@ -147,10 +142,10 @@ def preprocess_ondisk_dataset( ...@@ -147,10 +142,10 @@ def preprocess_ondisk_dataset(
output_config["graph_topology"] = {} output_config["graph_topology"] = {}
output_config["graph_topology"]["type"] = "FusedCSCSamplingGraph" output_config["graph_topology"]["type"] = "FusedCSCSamplingGraph"
output_config["graph_topology"]["path"] = os.path.join( output_config["graph_topology"]["path"] = os.path.join(
processed_dir_prefix, "fused_csc_sampling_graph.tar" processed_dir_prefix, "fused_csc_sampling_graph.pt"
) )
save_fused_csc_sampling_graph( torch.save(
fused_csc_sampling_graph, fused_csc_sampling_graph,
os.path.join( os.path.join(
dataset_dir, dataset_dir,
...@@ -452,7 +447,7 @@ class OnDiskDataset(Dataset): ...@@ -452,7 +447,7 @@ class OnDiskDataset(Dataset):
if graph_topology is None: if graph_topology is None:
return None return None
if graph_topology.type == "FusedCSCSamplingGraph": if graph_topology.type == "FusedCSCSamplingGraph":
return load_fused_csc_sampling_graph(graph_topology.path) return torch.load(graph_topology.path)
raise NotImplementedError( raise NotImplementedError(
f"Graph topology type {graph_topology.type} is not supported." f"Graph topology type {graph_topology.type} is not supported."
) )
......
...@@ -695,9 +695,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo( ...@@ -695,9 +695,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
orig_g = dgl.load_graphs( orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl") os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0] )[0][0]
new_g = dgl.graphbolt.load_fused_csc_sampling_graph( new_g = th.load(
os.path.join( os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar" test_dir, f"part{part_id}/fused_csc_sampling_graph.pt"
) )
) )
orig_indptr, orig_indices, _ = orig_g.adj().csc() orig_indptr, orig_indices, _ = orig_g.adj().csc()
...@@ -728,9 +728,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero( ...@@ -728,9 +728,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_g = dgl.load_graphs( orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl") os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0] )[0][0]
new_g = dgl.graphbolt.load_fused_csc_sampling_graph( new_g = th.load(
os.path.join( os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar" test_dir, f"part{part_id}/fused_csc_sampling_graph.pt"
) )
) )
orig_indptr, orig_indices, _ = orig_g.adj().csc() orig_indptr, orig_indices, _ = orig_g.adj().csc()
......
...@@ -297,9 +297,9 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges): ...@@ -297,9 +297,9 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
graph = gb.from_fused_csc(csc_indptr, indices) graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar") filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
gb.save_fused_csc_sampling_graph(graph, filename) torch.save(graph, filename)
graph2 = gb.load_fused_csc_sampling_graph(filename) graph2 = torch.load(filename)
assert graph.total_num_nodes == graph2.total_num_nodes assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges assert graph.total_num_edges == graph2.total_num_edges
...@@ -338,9 +338,9 @@ def test_load_save_hetero_graph( ...@@ -338,9 +338,9 @@ def test_load_save_hetero_graph(
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar") filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
gb.save_fused_csc_sampling_graph(graph, filename) torch.save(graph, filename)
graph2 = gb.load_fused_csc_sampling_graph(filename) graph2 = torch.load(filename)
assert graph.total_num_nodes == graph2.total_num_nodes assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges assert graph.total_num_edges == graph2.total_num_edges
...@@ -1103,7 +1103,7 @@ def test_homo_graph_on_shared_memory( ...@@ -1103,7 +1103,7 @@ def test_homo_graph_on_shared_memory(
shm_name = "test_homo_g" shm_name = "test_homo_g"
graph1 = graph.copy_to_shared_memory(shm_name) graph1 = graph.copy_to_shared_memory(shm_name)
graph2 = gb.load_from_shared_memory(shm_name, graph.metadata) graph2 = gb.load_from_shared_memory(shm_name)
assert graph1.total_num_nodes == total_num_nodes assert graph1.total_num_nodes == total_num_nodes
assert graph1.total_num_nodes == total_num_nodes assert graph1.total_num_nodes == total_num_nodes
...@@ -1181,7 +1181,7 @@ def test_hetero_graph_on_shared_memory( ...@@ -1181,7 +1181,7 @@ def test_hetero_graph_on_shared_memory(
shm_name = "test_hetero_g" shm_name = "test_hetero_g"
graph1 = graph.copy_to_shared_memory(shm_name) graph1 = graph.copy_to_shared_memory(shm_name)
graph2 = gb.load_from_shared_memory(shm_name, graph.metadata) graph2 = gb.load_from_shared_memory(shm_name)
assert graph1.total_num_nodes == total_num_nodes assert graph1.total_num_nodes == total_num_nodes
assert graph1.total_num_nodes == total_num_nodes assert graph1.total_num_nodes == total_num_nodes
......
...@@ -1008,8 +1008,8 @@ def test_OnDiskDataset_Graph_homogeneous(): ...@@ -1008,8 +1008,8 @@ def test_OnDiskDataset_Graph_homogeneous():
graph = gb.from_fused_csc(csc_indptr, indices) graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar") graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
gb.save_fused_csc_sampling_graph(graph, graph_path) torch.save(graph, graph_path)
yaml_content = f""" yaml_content = f"""
graph_topology: graph_topology:
...@@ -1046,8 +1046,8 @@ def test_OnDiskDataset_Graph_heterogeneous(): ...@@ -1046,8 +1046,8 @@ def test_OnDiskDataset_Graph_heterogeneous():
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar") graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
gb.save_fused_csc_sampling_graph(graph, graph_path) torch.save(graph, graph_path)
yaml_content = f""" yaml_content = f"""
graph_topology: graph_topology:
...@@ -1119,12 +1119,8 @@ def test_OnDiskDataset_preprocess_homogeneous(): ...@@ -1119,12 +1119,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
assert "graph" not in processed_dataset assert "graph" not in processed_dataset
assert "graph_topology" in processed_dataset assert "graph_topology" in processed_dataset
fused_csc_sampling_graph = ( fused_csc_sampling_graph = torch.load(
gb.fused_csc_sampling_graph.load_fused_csc_sampling_graph( os.path.join(test_dir, processed_dataset["graph_topology"]["path"])
os.path.join(
test_dir, processed_dataset["graph_topology"]["path"]
)
)
) )
assert fused_csc_sampling_graph.total_num_nodes == num_nodes assert fused_csc_sampling_graph.total_num_nodes == num_nodes
assert fused_csc_sampling_graph.total_num_edges == num_edges assert fused_csc_sampling_graph.total_num_edges == num_edges
...@@ -1166,12 +1162,8 @@ def test_OnDiskDataset_preprocess_homogeneous(): ...@@ -1166,12 +1162,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
) )
with open(output_file, "rb") as f: with open(output_file, "rb") as f:
processed_dataset = yaml.load(f, Loader=yaml.Loader) processed_dataset = yaml.load(f, Loader=yaml.Loader)
fused_csc_sampling_graph = ( fused_csc_sampling_graph = torch.load(
gb.fused_csc_sampling_graph.load_fused_csc_sampling_graph( os.path.join(test_dir, processed_dataset["graph_topology"]["path"])
os.path.join(
test_dir, processed_dataset["graph_topology"]["path"]
)
)
) )
assert ( assert (
fused_csc_sampling_graph.edge_attributes is not None fused_csc_sampling_graph.edge_attributes is not None
...@@ -1325,7 +1317,7 @@ def test_OnDiskDataset_preprocess_yaml_content_unix(): ...@@ -1325,7 +1317,7 @@ def test_OnDiskDataset_preprocess_yaml_content_unix():
dataset_name: {dataset_name} dataset_name: {dataset_name}
graph_topology: graph_topology:
type: FusedCSCSamplingGraph type: FusedCSCSamplingGraph
path: preprocessed/fused_csc_sampling_graph.tar path: preprocessed/fused_csc_sampling_graph.pt
feature_data: feature_data:
- domain: node - domain: node
type: null type: null
...@@ -1479,7 +1471,7 @@ def test_OnDiskDataset_preprocess_yaml_content_windows(): ...@@ -1479,7 +1471,7 @@ def test_OnDiskDataset_preprocess_yaml_content_windows():
dataset_name: {dataset_name} dataset_name: {dataset_name}
graph_topology: graph_topology:
type: FusedCSCSamplingGraph type: FusedCSCSamplingGraph
path: preprocessed\\fused_csc_sampling_graph.tar path: preprocessed\\fused_csc_sampling_graph.pt
feature_data: feature_data:
- domain: node - domain: node
type: null type: null
...@@ -1836,8 +1828,8 @@ def test_OnDiskDataset_all_nodes_set_homo(): ...@@ -1836,8 +1828,8 @@ def test_OnDiskDataset_all_nodes_set_homo():
graph = gb.from_fused_csc(csc_indptr, indices) graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar") graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
gb.save_fused_csc_sampling_graph(graph, graph_path) torch.save(graph, graph_path)
yaml_content = f""" yaml_content = f"""
graph_topology: graph_topology:
...@@ -1873,8 +1865,8 @@ def test_OnDiskDataset_all_nodes_set_hetero(): ...@@ -1873,8 +1865,8 @@ def test_OnDiskDataset_all_nodes_set_hetero():
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar") graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
gb.save_fused_csc_sampling_graph(graph, graph_path) torch.save(graph, graph_path)
yaml_content = f""" yaml_content = f"""
graph_topology: graph_topology:
...@@ -1999,7 +1991,7 @@ def test_BuiltinDataset(): ...@@ -1999,7 +1991,7 @@ def test_BuiltinDataset():
"""Test BuiltinDataset.""" """Test BuiltinDataset."""
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
# Case 1: download from DGL S3 storage. # Case 1: download from DGL S3 storage.
dataset_name = "test-dataset-231204" dataset_name = "test-dataset-231207"
# Add dataset to the builtin dataset list for testing only. # Add dataset to the builtin dataset list for testing only.
gb.BuiltinDataset._all_datasets.append(dataset_name) gb.BuiltinDataset._all_datasets.append(dataset_name)
dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load() dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment