Unverified Commit 41b905da authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Fix #1826 (#1835)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 54b71412
"""For Graph Serialization""" """For Graph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
import os
from ..graph import DGLGraph from ..graph import DGLGraph
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object from .._ffi.object import ObjectBase, register_object
...@@ -70,7 +71,7 @@ def save_graphs(filename, g_list, labels=None): ...@@ -70,7 +71,7 @@ def save_graphs(filename, g_list, labels=None):
Parameters Parameters
---------- ----------
filename : str filename : str
File name to store graphs. File name to store graphs.
g_list: list g_list: list
DGLGraph or list of DGLGraph/DGLHeteroGraph DGLGraph or list of DGLGraph/DGLHeteroGraph
labels: dict[str, tensor] labels: dict[str, tensor]
...@@ -81,7 +82,7 @@ def save_graphs(filename, g_list, labels=None): ...@@ -81,7 +82,7 @@ def save_graphs(filename, g_list, labels=None):
>>> import dgl >>> import dgl
>>> import torch as th >>> import torch as th
Create :code:`DGLGraph`/:code:`DGLHeteroGraph` objects and initialize node Create :code:`DGLGraph`/:code:`DGLHeteroGraph` objects and initialize node
and edge features. and edge features.
>>> g1 = dgl.graph(([0, 1, 2], [1, 2, 3]) >>> g1 = dgl.graph(([0, 1, 2], [1, 2, 3])
...@@ -95,6 +96,13 @@ def save_graphs(filename, g_list, labels=None): ...@@ -95,6 +96,13 @@ def save_graphs(filename, g_list, labels=None):
>>> save_graphs("./data.bin", [g1, g2], graph_labels) >>> save_graphs("./data.bin", [g1, g2], graph_labels)
""" """
# if it is local file, do some sanity check
if filename.startswith('s3://') is False:
assert not os.path.isdir(filename), "filename {} is an existing directory.".format(filename)
f_path, _ = os.path.split(filename)
if not os.path.exists(f_path):
os.makedirs(f_path)
g_sample = g_list[0] if isinstance(g_list, list) else g_list g_sample = g_list[0] if isinstance(g_list, list) else g_list
if isinstance(g_sample, DGLGraph): if isinstance(g_sample, DGLGraph):
save_dglgraphs(filename, g_list, labels) save_dglgraphs(filename, g_list, labels)
...@@ -149,6 +157,9 @@ def load_graphs(filename, idx_list=None): ...@@ -149,6 +157,9 @@ def load_graphs(filename, idx_list=None):
>>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1] >>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1]
""" """
# if it is local file, do some sanity check
assert filename.startswith('s3://') or os.path.exists(filename), "file {} does not exist.".format(filename)
version = _CAPI_GetFileVersion(filename) version = _CAPI_GetFileVersion(filename)
if version == 1: if version == 1:
return load_graph_v1(filename, idx_list) return load_graph_v1(filename, idx_list)
...@@ -204,6 +215,9 @@ def load_labels(filename): ...@@ -204,6 +215,9 @@ def load_labels(filename):
>>> label_dict = load_graphs("./data.bin") >>> label_dict = load_graphs("./data.bin")
""" """
# if it is local file, do some sanity check
assert filename.startswith('s3://') or os.path.exists(filename), "file {} does not exist.".format(filename)
version = _CAPI_GetFileVersion(filename) version = _CAPI_GetFileVersion(filename)
if version == 1: if version == 1:
return load_labels_v1(filename) return load_labels_v1(filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment