Unverified Commit 41b905da authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Fix #1826 (#1835)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 54b71412
"""For Graph Serialization"""
from __future__ import absolute_import
import os
from ..graph import DGLGraph
from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object
......@@ -95,6 +96,13 @@ def save_graphs(filename, g_list, labels=None):
>>> save_graphs("./data.bin", [g1, g2], graph_labels)
"""
# if it is local file, do some sanity check
if filename.startswith('s3://') is False:
assert not os.path.isdir(filename), "filename {} is an existing directory.".format(filename)
f_path, _ = os.path.split(filename)
if not os.path.exists(f_path):
os.makedirs(f_path)
g_sample = g_list[0] if isinstance(g_list, list) else g_list
if isinstance(g_sample, DGLGraph):
save_dglgraphs(filename, g_list, labels)
......@@ -149,6 +157,9 @@ def load_graphs(filename, idx_list=None):
>>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1]
"""
# if it is local file, do some sanity check
assert filename.startswith('s3://') or os.path.exists(filename), "file {} does not exist.".format(filename)
version = _CAPI_GetFileVersion(filename)
if version == 1:
return load_graph_v1(filename, idx_list)
......@@ -204,6 +215,9 @@ def load_labels(filename):
>>> label_dict = load_graphs("./data.bin")
"""
# if it is local file, do some sanity check
assert filename.startswith('s3://') or os.path.exists(filename), "file {} does not exist.".format(filename)
version = _CAPI_GetFileVersion(filename)
if version == 1:
return load_labels_v1(filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment