"tests/python/pytorch/nn/test_nn.py" did not exist on "290b7c25c25eb3f10372c5423d6f12da1031955e"
Unverified Commit c2134442 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] use torch.load/save instead of load/save_fused_xxx() (#6707)

parent 0348ad3d
......@@ -21,11 +21,16 @@ namespace torch {
/**
* @brief Overload input stream operator for FusedCSCSamplingGraph
* deserialization.
* deserialization. This enables `torch::load()` for FusedCSCSamplingGraph.
*
* @param archive Input stream for deserializing.
* @param graph FusedCSCSamplingGraph.
*
* @return archive
*
* @code
* auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
* torch::load(*graph, filename);
* @endcode
*/
inline serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
......@@ -33,11 +38,15 @@ inline serialize::InputArchive& operator>>(
/**
* @brief Overload output stream operator for FusedCSCSamplingGraph
* serialization.
* serialization. This enables `torch::save()` for FusedCSCSamplingGraph.
* @param archive Output stream for serializing.
* @param graph FusedCSCSamplingGraph.
*
* @return archive
*
* @code
* auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
* torch::save(*graph, filename);
* @endcode
*/
inline serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
......@@ -47,25 +56,6 @@ inline serialize::OutputArchive& operator<<(
namespace graphbolt {
/**
* @brief Load FusedCSCSamplingGraph from file.
* @param filename File name to read.
*
* @return FusedCSCSamplingGraph.
*/
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
const std::string& filename);
/**
* @brief Save FusedCSCSamplingGraph to file.
* @param graph FusedCSCSamplingGraph to save.
* @param filename File name to save.
*
*/
void SaveFusedCSCSamplingGraph(
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
const std::string& filename);
/**
* @brief Read data from archive.
* @param archive Input archive.
......
......@@ -66,8 +66,6 @@ TORCH_LIBRARY(graphbolt, m) {
return g;
});
m.def("from_fused_csc", &FusedCSCSamplingGraph::FromCSC);
m.def("load_fused_csc_sampling_graph", &LoadFusedCSCSamplingGraph);
m.def("save_fused_csc_sampling_graph", &SaveFusedCSCSamplingGraph);
m.def(
"load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
m.def("unique_and_compact", &UniqueAndCompact);
......
......@@ -27,19 +27,6 @@ serialize::OutputArchive& operator<<(
namespace graphbolt {
/**
 * @brief Load FusedCSCSamplingGraph from file.
 *
 * Creates an empty FusedCSCSamplingGraph and fills it in place with
 * `torch::load()` from the given file.
 *
 * @param filename File name to read.
 *
 * @return Intrusive pointer to the loaded FusedCSCSamplingGraph.
 */
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
    const std::string& filename) {
  // Start from a default-constructed graph; torch::load() populates it.
  auto loaded = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
  torch::load(*loaded, filename);
  return loaded;
}
/**
 * @brief Save FusedCSCSamplingGraph to file.
 *
 * Hands the pointed-to graph to `torch::save()`, which writes it to the
 * given file.
 *
 * @param graph FusedCSCSamplingGraph to save.
 * @param filename File name to save.
 */
void SaveFusedCSCSamplingGraph(
    c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
    const std::string& filename) {
  // Serialization only reads the graph, so bind it as a const reference.
  const sampling::FusedCSCSamplingGraph& target = *graph;
  torch::save(target, filename);
}
torch::IValue read_from_archive(
torch::serialize::InputArchive& archive, const std::string& key) {
torch::IValue data;
......
......@@ -7,6 +7,8 @@ import time
import numpy as np
import torch
from .. import backend as F
from ..base import DGLError, EID, ETYPE, NID, NTYPE
from ..convert import to_homogeneous
......@@ -1236,6 +1238,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
part_config : str
The partition configuration JSON file.
"""
# As only this function requires GraphBolt for now, let's import here.
from .. import graphbolt
......@@ -1279,6 +1282,6 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
part_meta[f"part-{part_id}"]["part_graph"],
)
csc_graph_path = os.path.join(
os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.tar"
os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.pt"
)
graphbolt.save_fused_csc_sampling_graph(csc_graph, csc_graph_path)
torch.save(csc_graph, csc_graph_path)
"""CSC format sampling graph."""
# pylint: disable= invalid-name
import os
import tarfile
import tempfile
from collections import defaultdict
from typing import Dict, Optional, Union
......@@ -27,8 +24,6 @@ __all__ = [
"FusedCSCSamplingGraph",
"from_fused_csc",
"load_from_shared_memory",
"load_fused_csc_sampling_graph",
"save_fused_csc_sampling_graph",
"from_dglgraph",
]
......@@ -99,11 +94,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
return _csc_sampling_graph_str(self)
def __init__(
self, c_csc_graph: torch.ScriptObject, metadata: Optional[GraphMetadata]
self,
c_csc_graph: torch.ScriptObject,
):
super().__init__()
self._c_csc_graph = c_csc_graph
self._metadata = metadata
@property
def total_num_nodes(self) -> int:
......@@ -318,12 +313,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph.
[TODO][Rui] This API needs to be updated.
Returns
-------
GraphMetadata or None
If present, returns the metadata of the graph.
"""
return self._metadata
if self.node_type_to_id is None or self.edge_type_to_id is None:
return None
return GraphMetadata(self.node_type_to_id, self.edge_type_to_id)
def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
......@@ -884,7 +883,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""
return FusedCSCSamplingGraph(
self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
self._metadata,
)
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
......@@ -975,13 +973,11 @@ def from_fused_csc(
edge_type_to_id,
edge_attributes,
),
metadata,
)
def load_from_shared_memory(
shared_memory_name: str,
metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph:
"""Load a FusedCSCSamplingGraph object from shared memory.
......@@ -997,7 +993,6 @@ def load_from_shared_memory(
"""
return FusedCSCSamplingGraph(
torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),
metadata,
)
......@@ -1033,38 +1028,6 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
return final_str
def load_fused_csc_sampling_graph(filename):
    """Load a FusedCSCSamplingGraph from a tar file.

    The archive is unpacked into a temporary directory. The member
    ``fused_csc_sampling_graph.pt`` is read with the GraphBolt loading op and
    the member ``metadata.pt`` is read with ``torch.load``; both results are
    passed to the ``FusedCSCSamplingGraph`` constructor.

    Parameters
    ----------
    filename : str
        Path of the tar file to read.

    Returns
    -------
    FusedCSCSamplingGraph
        The graph built from the two archive members.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        # NOTE(review): ``extractall`` unpacks every member of the archive and
        # ``torch.load`` unpickles its input, so this presumably expects
        # trusted archives only -- confirm with callers.
        with tarfile.open(filename, "r") as tar:
            tar.extractall(scratch_dir)
        c_graph = torch.ops.graphbolt.load_fused_csc_sampling_graph(
            os.path.join(scratch_dir, "fused_csc_sampling_graph.pt")
        )
        metadata = torch.load(os.path.join(scratch_dir, "metadata.pt"))
        # Build the result while the temporary directory still exists.
        return FusedCSCSamplingGraph(c_graph, metadata)
def save_fused_csc_sampling_graph(graph, filename):
    """Save a FusedCSCSamplingGraph to a tar file.

    Two files are first written to a temporary directory:
    ``fused_csc_sampling_graph.pt`` (``graph._c_csc_graph``, written by the
    GraphBolt saving op) and ``metadata.pt`` (``graph.metadata``, written by
    ``torch.save``). Both are then added to the tar archive under their base
    names, and a confirmation message is printed.

    Parameters
    ----------
    graph : FusedCSCSamplingGraph
        The graph to save.
    filename : str
        Path of the tar file to write.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        graph_path = os.path.join(scratch_dir, "fused_csc_sampling_graph.pt")
        torch.ops.graphbolt.save_fused_csc_sampling_graph(
            graph._c_csc_graph, graph_path
        )
        metadata_path = os.path.join(scratch_dir, "metadata.pt")
        torch.save(graph.metadata, metadata_path)
        with tarfile.open(filename, "w") as tar:
            # Store each file under its base name so the archive layout does
            # not depend on the temporary directory path.
            for member_path in (graph_path, metadata_path):
                tar.add(member_path, arcname=os.path.basename(member_path))
    print(f"FusedCSCSamplingGraph has been saved to {filename}.")
def from_dglgraph(
g: DGLGraph,
is_homogeneous: bool = False,
......@@ -1114,5 +1077,4 @@ def from_dglgraph(
edge_type_to_id,
edge_attributes,
),
metadata,
)
......@@ -17,12 +17,7 @@ from ..dataset import Dataset, Task
from ..internal import copy_or_convert_data, read_data
from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph
from .fused_csc_sampling_graph import (
from_dglgraph,
FusedCSCSamplingGraph,
load_fused_csc_sampling_graph,
save_fused_csc_sampling_graph,
)
from .fused_csc_sampling_graph import from_dglgraph, FusedCSCSamplingGraph
from .ondisk_metadata import (
OnDiskGraphTopology,
OnDiskMetaData,
......@@ -147,10 +142,10 @@ def preprocess_ondisk_dataset(
output_config["graph_topology"] = {}
output_config["graph_topology"]["type"] = "FusedCSCSamplingGraph"
output_config["graph_topology"]["path"] = os.path.join(
processed_dir_prefix, "fused_csc_sampling_graph.tar"
processed_dir_prefix, "fused_csc_sampling_graph.pt"
)
save_fused_csc_sampling_graph(
torch.save(
fused_csc_sampling_graph,
os.path.join(
dataset_dir,
......@@ -452,7 +447,7 @@ class OnDiskDataset(Dataset):
if graph_topology is None:
return None
if graph_topology.type == "FusedCSCSamplingGraph":
return load_fused_csc_sampling_graph(graph_topology.path)
return torch.load(graph_topology.path)
raise NotImplementedError(
f"Graph topology type {graph_topology.type} is not supported."
)
......
......@@ -695,9 +695,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = dgl.graphbolt.load_fused_csc_sampling_graph(
new_g = th.load(
os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar"
test_dir, f"part{part_id}/fused_csc_sampling_graph.pt"
)
)
orig_indptr, orig_indices, _ = orig_g.adj().csc()
......@@ -728,9 +728,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = dgl.graphbolt.load_fused_csc_sampling_graph(
new_g = th.load(
os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar"
test_dir, f"part{part_id}/fused_csc_sampling_graph.pt"
)
)
orig_indptr, orig_indices, _ = orig_g.adj().csc()
......
......@@ -297,9 +297,9 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, filename)
graph2 = gb.load_fused_csc_sampling_graph(filename)
filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
torch.save(graph, filename)
graph2 = torch.load(filename)
assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges
......@@ -338,9 +338,9 @@ def test_load_save_hetero_graph(
)
with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, filename)
graph2 = gb.load_fused_csc_sampling_graph(filename)
filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
torch.save(graph, filename)
graph2 = torch.load(filename)
assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges
......@@ -1103,7 +1103,7 @@ def test_homo_graph_on_shared_memory(
shm_name = "test_homo_g"
graph1 = graph.copy_to_shared_memory(shm_name)
graph2 = gb.load_from_shared_memory(shm_name, graph.metadata)
graph2 = gb.load_from_shared_memory(shm_name)
assert graph1.total_num_nodes == total_num_nodes
assert graph1.total_num_nodes == total_num_nodes
......@@ -1181,7 +1181,7 @@ def test_hetero_graph_on_shared_memory(
shm_name = "test_hetero_g"
graph1 = graph.copy_to_shared_memory(shm_name)
graph2 = gb.load_from_shared_memory(shm_name, graph.metadata)
graph2 = gb.load_from_shared_memory(shm_name)
assert graph1.total_num_nodes == total_num_nodes
assert graph1.total_num_nodes == total_num_nodes
......
......@@ -1008,8 +1008,8 @@ def test_OnDiskDataset_Graph_homogeneous():
graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, graph_path)
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
torch.save(graph, graph_path)
yaml_content = f"""
graph_topology:
......@@ -1046,8 +1046,8 @@ def test_OnDiskDataset_Graph_heterogeneous():
)
with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, graph_path)
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
torch.save(graph, graph_path)
yaml_content = f"""
graph_topology:
......@@ -1119,12 +1119,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
assert "graph" not in processed_dataset
assert "graph_topology" in processed_dataset
fused_csc_sampling_graph = (
gb.fused_csc_sampling_graph.load_fused_csc_sampling_graph(
os.path.join(
test_dir, processed_dataset["graph_topology"]["path"]
)
)
fused_csc_sampling_graph = torch.load(
os.path.join(test_dir, processed_dataset["graph_topology"]["path"])
)
assert fused_csc_sampling_graph.total_num_nodes == num_nodes
assert fused_csc_sampling_graph.total_num_edges == num_edges
......@@ -1166,12 +1162,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
)
with open(output_file, "rb") as f:
processed_dataset = yaml.load(f, Loader=yaml.Loader)
fused_csc_sampling_graph = (
gb.fused_csc_sampling_graph.load_fused_csc_sampling_graph(
os.path.join(
test_dir, processed_dataset["graph_topology"]["path"]
)
)
fused_csc_sampling_graph = torch.load(
os.path.join(test_dir, processed_dataset["graph_topology"]["path"])
)
assert (
fused_csc_sampling_graph.edge_attributes is not None
......@@ -1325,7 +1317,7 @@ def test_OnDiskDataset_preprocess_yaml_content_unix():
dataset_name: {dataset_name}
graph_topology:
type: FusedCSCSamplingGraph
path: preprocessed/fused_csc_sampling_graph.tar
path: preprocessed/fused_csc_sampling_graph.pt
feature_data:
- domain: node
type: null
......@@ -1479,7 +1471,7 @@ def test_OnDiskDataset_preprocess_yaml_content_windows():
dataset_name: {dataset_name}
graph_topology:
type: FusedCSCSamplingGraph
path: preprocessed\\fused_csc_sampling_graph.tar
path: preprocessed\\fused_csc_sampling_graph.pt
feature_data:
- domain: node
type: null
......@@ -1836,8 +1828,8 @@ def test_OnDiskDataset_all_nodes_set_homo():
graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, graph_path)
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
torch.save(graph, graph_path)
yaml_content = f"""
graph_topology:
......@@ -1873,8 +1865,8 @@ def test_OnDiskDataset_all_nodes_set_hetero():
)
with tempfile.TemporaryDirectory() as test_dir:
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, graph_path)
graph_path = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
torch.save(graph, graph_path)
yaml_content = f"""
graph_topology:
......@@ -1999,7 +1991,7 @@ def test_BuiltinDataset():
"""Test BuiltinDataset."""
with tempfile.TemporaryDirectory() as test_dir:
# Case 1: download from DGL S3 storage.
dataset_name = "test-dataset-231204"
dataset_name = "test-dataset-231207"
# Add dataset to the builtin dataset list for testing only.
gb.BuiltinDataset._all_datasets.append(dataset_name)
dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment