"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a7e4fbdc925a5968988ccadd6dffe7abe274dcdc"
Unverified Commit 41b905da authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Fix #1826 (#1835)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 54b71412
"""For Graph Serialization"""
from __future__ import absolute_import
import os
from ..graph import DGLGraph
from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object
......@@ -70,7 +71,7 @@ def save_graphs(filename, g_list, labels=None):
Parameters
----------
filename : str
File name to store graphs.
File name to store graphs.
g_list: list
DGLGraph or list of DGLGraph/DGLHeteroGraph
labels: dict[str, tensor]
......@@ -81,7 +82,7 @@ def save_graphs(filename, g_list, labels=None):
>>> import dgl
>>> import torch as th
Create :code:`DGLGraph`/:code:`DGLHeteroGraph` objects and initialize node
Create :code:`DGLGraph`/:code:`DGLHeteroGraph` objects and initialize node
and edge features.
>>> g1 = dgl.graph(([0, 1, 2], [1, 2, 3])
......@@ -95,6 +96,13 @@ def save_graphs(filename, g_list, labels=None):
>>> save_graphs("./data.bin", [g1, g2], graph_labels)
"""
# if it is local file, do some sanity check
if filename.startswith('s3://') is False:
assert not os.path.isdir(filename), "filename {} is an existing directory.".format(filename)
f_path, _ = os.path.split(filename)
if not os.path.exists(f_path):
os.makedirs(f_path)
g_sample = g_list[0] if isinstance(g_list, list) else g_list
if isinstance(g_sample, DGLGraph):
save_dglgraphs(filename, g_list, labels)
......@@ -149,6 +157,9 @@ def load_graphs(filename, idx_list=None):
>>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1]
"""
# if it is local file, do some sanity check
assert filename.startswith('s3://') or os.path.exists(filename), "file {} does not exist.".format(filename)
version = _CAPI_GetFileVersion(filename)
if version == 1:
return load_graph_v1(filename, idx_list)
......@@ -204,6 +215,9 @@ def load_labels(filename):
>>> label_dict = load_graphs("./data.bin")
"""
# if it is local file, do some sanity check
assert filename.startswith('s3://') or os.path.exists(filename), "file {} does not exist.".format(filename)
version = _CAPI_GetFileVersion(filename)
if version == 1:
return load_labels_v1(filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment