Unverified Commit c78ddee2 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Feature] enable to save graphs in multiple formats (#4266)

* [Feature] enable to save graphs in multiple formats

* use StreamWithCount to accommodate formats

* fix lint

* fix dynamic cast issue

* fix lint

* refine docstring and code naming

* update docstring
parent 6a460725
......@@ -26,6 +26,7 @@ enum class SparseFormat {
* \brief Sparse format codes
*/
const dgl_format_code_t ALL_CODE = 0x7;
const dgl_format_code_t ANY_CODE = 0x0;
const dgl_format_code_t COO_CODE = 0x1;
const dgl_format_code_t CSR_CODE = 0x2;
const dgl_format_code_t CSC_CODE = 0x4;
......
......@@ -82,7 +82,7 @@ class GraphData(ObjectBase):
return g
def save_graphs(filename, g_list, labels=None):
def save_graphs(filename, g_list, labels=None, formats=None):
r"""Save graphs and optionally their labels to file.
Besides saving to local files, DGL supports writing the graphs directly
......@@ -101,6 +101,12 @@ def save_graphs(filename, g_list, labels=None):
The graphs to be saved.
labels: dict[str, Tensor]
labels should be dict of tensors, with str as keys
formats: str or list[str]
Save graph in specified formats. It could be any combination of
``coo``, ``csc`` and ``csr``. If not specified, save one format
only according to what format is available. If multiple formats
are available, selection priority from high to low is ``coo``,
``csc``, ``csr``.
Examples
----------
......@@ -138,7 +144,7 @@ def save_graphs(filename, g_list, labels=None):
if (
type(g_sample) == DGLHeteroGraph
): # Doesn't support DGLHeteroGraph's derived class
save_heterographs(filename, g_list, labels)
save_heterographs(filename, g_list, labels, formats)
else:
raise DGLError(
"Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs."
......
......@@ -19,7 +19,7 @@ def tensor_dict_to_ndarray_dict(tensor_dict):
return convert_to_strmap(ndarray_dict)
def save_heterographs(filename, g_list, labels):
def save_heterographs(filename, g_list, labels, formats):
"""Save heterographs into file"""
if labels is None:
labels = {}
......@@ -29,11 +29,14 @@ def save_heterographs(filename, g_list, labels):
[type(g) == DGLHeteroGraph for g in g_list]
), "Invalid DGLHeteroGraph in g_list argument"
gdata_list = [HeteroGraphData.create(g) for g in g_list]
if formats is None:
formats = []
elif isinstance(formats, str):
formats = [formats]
_CAPI_SaveHeteroGraphData(
filename, gdata_list, tensor_dict_to_ndarray_dict(labels)
filename, gdata_list, tensor_dict_to_ndarray_dict(labels), formats
)
@register_object("heterograph_serialize.HeteroGraphData")
class HeteroGraphData(ObjectBase):
"""Object to hold the data to be stored for DGLHeteroGraph"""
......
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/streamwithcount.h
* \file graph/serialize/dglstream.h
* \brief Graph serialization header
*/
#ifndef DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_
#define DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_
#ifndef DGL_GRAPH_SERIALIZE_DGLSTREAM_H_
#define DGL_GRAPH_SERIALIZE_DGLSTREAM_H_
#include <dgl/aten/spmat.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <memory>
namespace dgl {
namespace serialize {
/*!
* \brief StreamWithCount counts the bytes that already written into the
* \brief DGLStream counts the bytes that have already been written into the
* underlying stream.
*/
class StreamWithCount : dmlc::Stream {
class DGLStream : public dmlc::Stream {
public:
static StreamWithCount *Create(const char *uri, const char *const flag,
bool allow_null = false) {
return new StreamWithCount(uri, flag, allow_null);
/*! \brief create a new DGLStream instance */
static DGLStream *Create(const char *uri, const char *const flag,
bool allow_null, dgl_format_code_t formats) {
return new DGLStream(uri, flag, allow_null, formats);
}
size_t Read(void *ptr, size_t size) override {
......@@ -37,11 +42,21 @@ class StreamWithCount : dmlc::Stream {
uint64_t Count() const { return count_; }
uint64_t FormatsToSave() const { return formats_to_save_; }
private:
StreamWithCount(const char *uri, const char *const flag, bool allow_null)
: strm_(dmlc::Stream::Create(uri, flag, allow_null)) {}
DGLStream(const char *uri, const char *const flag, bool allow_null,
dgl_format_code_t formats)
: strm_(dmlc::Stream::Create(uri, flag, allow_null)), formats_to_save_(formats) {
}
// stream for serialization
std::unique_ptr<dmlc::Stream> strm_;
// number of bytes already written to the stream
uint64_t count_ = 0;
// formats to use when saving graph
const dgl_format_code_t formats_to_save_ = ANY_CODE;
};
} // namespace serialize
} // namespace dgl
#endif // DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_
#endif // DGL_GRAPH_SERIALIZE_DGLSTREAM_H_
......@@ -49,7 +49,7 @@
#include "../heterograph.h"
#include "./graph_serialize.h"
#include "./streamwithcount.h"
#include "./dglstream.h"
#include "dmlc/memory_io.h"
namespace dgl {
......@@ -62,9 +62,10 @@ using dmlc::io::FileSystem;
using dmlc::io::URI;
bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata,
const std::vector<NamedTensor> &nd_list) {
auto fs = std::unique_ptr<StreamWithCount>(
StreamWithCount::Create(filename.c_str(), "w", false));
const std::vector<NamedTensor> &nd_list,
dgl_format_code_t formats) {
auto fs = std::unique_ptr<DGLStream>(
DGLStream::Create(filename.c_str(), "w", false, formats));
CHECK(fs->IsValid()) << "File name " << filename << " is not a valid name";
// Write DGL MetaData
......@@ -221,12 +222,18 @@ DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData")
std::string filename = args[0];
List<HeteroGraphData> hgdata = args[1];
Map<std::string, Value> nd_map = args[2];
List<Value> formats = args[3];
std::vector<SparseFormat> formats_vec;
for (const auto& val : formats) {
formats_vec.push_back(ParseSparseFormat(val->data));
}
const auto formats_code = SparseFormatsToCode(formats_vec);
std::vector<NamedTensor> nd_list;
for (auto kv : nd_map) {
NDArray ndarray = static_cast<NDArray>(kv.second->data);
nd_list.emplace_back(kv.first, ndarray);
}
*rv = dgl::serialize::SaveHeteroGraphs(filename, hgdata, nd_list);
*rv = dgl::serialize::SaveHeteroGraphs(filename, hgdata, nd_list, formats_code);
});
DGL_REGISTER_GLOBAL(
......
......@@ -10,6 +10,7 @@
#include "../c_api_common.h"
#include "./unit_graph.h"
#include "./serialize/dglstream.h"
namespace dgl {
......@@ -1648,7 +1649,14 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
int64_t save_format_code, formats_code;
CHECK(fs->Read(&save_format_code)) << "Invalid format";
CHECK(fs->Read(&formats_code)) << "Invalid format";
auto save_format = static_cast<SparseFormat>(save_format_code);
dgl_format_code_t save_formats = ANY_CODE;
if (save_format_code >> 32) {
save_formats =
static_cast<dgl_format_code_t>(0xffffffff & save_format_code);
} else {
save_formats =
SparseFormatsToCode({static_cast<SparseFormat>(save_format_code)});
}
if (formats_code >> 32) {
formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);
} else {
......@@ -1672,19 +1680,17 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
}
}
switch (save_format) {
case SparseFormat::kCOO:
if (save_formats & COO_CODE) {
fs->Read(&coo_);
break;
case SparseFormat::kCSR:
}
if (save_formats & CSR_CODE) {
fs->Read(&out_csr_);
break;
case SparseFormat::kCSC:
}
if (save_formats & CSC_CODE) {
fs->Read(&in_csr_);
break;
default:
}
if (!coo_ && !out_csr_ && !in_csr_) {
LOG(FATAL) << "unsupported format code";
break;
}
if (!in_csr_) {
......@@ -1707,22 +1713,24 @@ void UnitGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_UnitGraphMagic);
// Didn't write UnitGraph::meta_graph_, since it's included in the underlying
// sparse matrix
auto avail_fmt = SelectFormat(ALL_CODE);
fs->Write(static_cast<int64_t>(avail_fmt));
auto save_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)});
auto fstream = dynamic_cast<dgl::serialize::DGLStream *>(fs);
if (fstream) {
auto formats = fstream->FormatsToSave();
save_formats = formats == ANY_CODE
? SparseFormatsToCode({SelectFormat(ALL_CODE)})
: formats;
}
fs->Write(static_cast<int64_t>(save_formats | 0x100000000));
fs->Write(static_cast<int64_t>(formats_ | 0x100000000));
switch (avail_fmt) {
case SparseFormat::kCOO:
if (save_formats & COO_CODE) {
fs->Write(GetCOO());
break;
case SparseFormat::kCSR:
}
if (save_formats & CSR_CODE) {
fs->Write(GetOutCSR());
break;
case SparseFormat::kCSC:
}
if (save_formats & CSC_CODE) {
fs->Write(GetInCSR());
break;
default:
LOG(FATAL) << "unsupported format code";
break;
}
}
......
......@@ -364,19 +364,84 @@ def test_serialize_heterograph_s3():
assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))
if __name__ == "__main__":
pass
# test_graph_serialize_with_feature(True)
# test_graph_serialize_with_feature(False)
# test_graph_serialize_without_feature(True)
# test_graph_serialize_without_feature(False)
# test_graph_serialize_with_labels(True)
# test_graph_serialize_with_labels(False)
# test_serialize_tensors()
# test_serialize_empty_dict()
# test_load_old_files1()
test_load_old_files2()
# test_serialize_heterograph()
# test_serialize_heterograph_s3()
# test_deserialize_old_heterograph_file()
# create_old_heterograph_files()
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
@pytest.mark.parametrize("is_hetero", [True, False])
@pytest.mark.parametrize(
    "formats",
    [
        "coo",
        "csr",
        "csc",
        ["coo", "csc"],
        ["coo", "csr"],
        ["csc", "csr"],
        ["coo", "csr", "csc"],
    ],
)
def test_graph_serialize_with_formats(is_hetero, formats):
    """Save graphs with explicit sparse formats and verify the round trip.

    A list of random graphs is written with ``dgl.save_graphs(...,
    formats=formats)`` and loaded back in a shuffled order. The first loaded
    graph must report every requested format as "created" and must have the
    same nodes and edges as the corresponding original graph.

    Parameters
    ----------
    is_hetero : bool
        Forwarded to ``generate_rand_graph``.
    formats : str or list[str]
        Format name(s) passed to ``dgl.save_graphs``.
    """
    num_graphs = 100
    g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)]

    # create a temporary file and immediately release it so DGL can open it.
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()

    # The file was created with delete=False, so remove it ourselves even when
    # saving, loading or one of the assertions below fails.
    try:
        dgl.save_graphs(path, g_list, formats=formats)

        idx_list = np.random.permutation(np.arange(num_graphs)).tolist()
        loadg_list, _ = dgl.load_graphs(path, idx_list)

        # Only the first graph of the shuffled load order is compared.
        idx = idx_list[0]
        load_g = loadg_list[0]
        g_formats = load_g.formats()

        # verify formats
        if not isinstance(formats, list):
            formats = [formats]
        for fmt in formats:
            assert fmt in g_formats["created"]

        assert F.allclose(load_g.nodes(), g_list[idx].nodes())

        load_edges = load_g.all_edges("uv", "eid")
        g_edges = g_list[idx].all_edges("uv", "eid")
        assert F.allclose(load_edges[0], g_edges[0])
        assert F.allclose(load_edges[1], g_edges[1])
    finally:
        os.unlink(path)
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
def test_graph_serialize_with_restricted_formats():
    """Saving in a format the graph is restricted from using must raise.

    The graph is restricted to ``coo`` via ``g.formats(["coo"])`` and then
    saved with ``formats=["csr"]``; the test passes only if
    ``dgl.save_graphs`` raises.
    """
    g = dgl.rand_graph(100, 200)
    g = g.formats(["coo"])
    g_list = [g]

    # create a temporary file and immediately release it so DGL can open it.
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()

    # Remove the delete=False temp file whether or not the expectation holds.
    try:
        # Expect an ordinary exception; unlike a bare ``except:``, this does
        # not treat KeyboardInterrupt/SystemExit as a passing result.
        with pytest.raises(Exception):
            dgl.save_graphs(path, g_list, formats=["csr"])
    finally:
        os.unlink(path)
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
def test_deserialize_old_graph():
    """Load a previously stored graph file and check its formats and size.

    The fixture ``data/graph_0.9a220622.dgl`` is read from the directory of
    this test module. The first loaded graph must have only ``coo`` created
    and must contain 100 nodes and 200 edges.
    """
    expected_nodes, expected_edges = 100, 200

    fixture = os.path.join(
        os.path.dirname(__file__), "data/graph_0.9a220622.dgl"
    )
    graphs, _ = dgl.load_graphs(fixture)
    loaded = graphs[0]

    # Query the format status once and reuse it for all three checks.
    fmt_status = loaded.formats()
    assert "coo" in fmt_status["created"]
    for missing in ("csr", "csc"):
        assert missing in fmt_status["not created"]

    assert expected_nodes == loaded.num_nodes()
    assert expected_edges == loaded.num_edges()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment