Unverified Commit 3234189b authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Fix hdfs error when writing big data (#2186)

* update dmlc-core to latest

* fix hdfs
parent 55f3e926
...@@ -25,6 +25,16 @@ class StorageMetaData(ObjectBase): ...@@ -25,6 +25,16 @@ class StorageMetaData(ObjectBase):
""" """
def is_local_path(filepath):
return not (filepath.startswith("hdfs://") or
filepath.startswith("viewfs://") or
filepath.startswith("s3://"))
def check_local_file_exists(filename):
if is_local_path(filename) and not os.path.exists(filename):
raise DGLError("File {} does not exist.".format(filename))
@register_object("graph_serialize.GraphData") @register_object("graph_serialize.GraphData")
class GraphData(ObjectBase): class GraphData(ObjectBase):
"""GraphData Object""" """GraphData Object"""
...@@ -108,7 +118,7 @@ def save_graphs(filename, g_list, labels=None): ...@@ -108,7 +118,7 @@ def save_graphs(filename, g_list, labels=None):
load_graphs load_graphs
""" """
# if it is local file, do some sanity check # if it is local file, do some sanity check
if not filename.startswith('s3://'): if is_local_path(filename):
if os.path.isdir(filename): if os.path.isdir(filename):
raise DGLError("Filename {} is an existing directory.".format(filename)) raise DGLError("Filename {} is an existing directory.".format(filename))
f_path = os.path.dirname(filename) f_path = os.path.dirname(filename)
...@@ -161,9 +171,7 @@ def load_graphs(filename, idx_list=None): ...@@ -161,9 +171,7 @@ def load_graphs(filename, idx_list=None):
save_graphs save_graphs
""" """
# if it is local file, do some sanity check # if it is local file, do some sanity check
if not (filename.startswith('s3://') or os.path.exists(filename)): check_local_file_exists(filename)
raise DGLError("File {} does not exist.".format(filename))
version = _CAPI_GetFileVersion(filename) version = _CAPI_GetFileVersion(filename)
if version == 1: if version == 1:
dgl_warning( dgl_warning(
...@@ -222,7 +230,7 @@ def load_labels(filename): ...@@ -222,7 +230,7 @@ def load_labels(filename):
""" """
# if it is local file, do some sanity check # if it is local file, do some sanity check
assert filename.startswith('s3://') or os.path.exists(filename), "file {} does not exist.".format(filename) check_local_file_exists(filename)
version = _CAPI_GetFileVersion(filename) version = _CAPI_GetFileVersion(filename)
if version == 1: if version == 1:
......
Subproject commit 16c6f68c09af7ed2762cedcd2017307baaf875ed Subproject commit bfad207b448480783a1f428ae3d93d87032d8349
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment