"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "1b37545d79ca9d8a15b99c5b0fa546d55926383c"
Unverified Commit 41b905da authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Fix #1826 (#1835)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 54b71412
"""For Graph Serialization""" """For Graph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
import os
from ..graph import DGLGraph from ..graph import DGLGraph
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object from .._ffi.object import ObjectBase, register_object
...@@ -95,6 +96,13 @@ def save_graphs(filename, g_list, labels=None): ...@@ -95,6 +96,13 @@ def save_graphs(filename, g_list, labels=None):
>>> save_graphs("./data.bin", [g1, g2], graph_labels) >>> save_graphs("./data.bin", [g1, g2], graph_labels)
""" """
# if it is local file, do some sanity check
if filename.startswith('s3://') is False:
assert not os.path.isdir(filename), "filename {} is an existing directory.".format(filename)
f_path, _ = os.path.split(filename)
if not os.path.exists(f_path):
os.makedirs(f_path)
g_sample = g_list[0] if isinstance(g_list, list) else g_list g_sample = g_list[0] if isinstance(g_list, list) else g_list
if isinstance(g_sample, DGLGraph): if isinstance(g_sample, DGLGraph):
save_dglgraphs(filename, g_list, labels) save_dglgraphs(filename, g_list, labels)
...@@ -149,6 +157,9 @@ def load_graphs(filename, idx_list=None): ...@@ -149,6 +157,9 @@ def load_graphs(filename, idx_list=None):
>>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1] >>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1]
""" """
# if it is local file, do some sanity check
assert filename.startswith('s3://') or os.path.exists(filename), "file {} does not exist.".format(filename)
version = _CAPI_GetFileVersion(filename) version = _CAPI_GetFileVersion(filename)
if version == 1: if version == 1:
return load_graph_v1(filename, idx_list) return load_graph_v1(filename, idx_list)
...@@ -204,6 +215,9 @@ def load_labels(filename): ...@@ -204,6 +215,9 @@ def load_labels(filename):
>>> label_dict = load_graphs("./data.bin") >>> label_dict = load_graphs("./data.bin")
""" """
# if it is local file, do some sanity check
assert filename.startswith('s3://') or os.path.exists(filename), "file {} does not exist.".format(filename)
version = _CAPI_GetFileVersion(filename) version = _CAPI_GetFileVersion(filename)
if version == 1: if version == 1:
return load_labels_v1(filename) return load_labels_v1(filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment