Unverified Commit c78ddee2 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Feature] enable to save graphs in multiple formats (#4266)

* [Feature] enable to save graphs in multiple formats

* use StreamWithCount to accomodate formats

* fix lint

* fix dynamic cast issue

* fix lint

* refine docstring and code naming

* update docstring
parent 6a460725
...@@ -26,6 +26,7 @@ enum class SparseFormat { ...@@ -26,6 +26,7 @@ enum class SparseFormat {
* \brief Sparse format codes * \brief Sparse format codes
*/ */
const dgl_format_code_t ALL_CODE = 0x7; const dgl_format_code_t ALL_CODE = 0x7;
const dgl_format_code_t ANY_CODE = 0x0;
const dgl_format_code_t COO_CODE = 0x1; const dgl_format_code_t COO_CODE = 0x1;
const dgl_format_code_t CSR_CODE = 0x2; const dgl_format_code_t CSR_CODE = 0x2;
const dgl_format_code_t CSC_CODE = 0x4; const dgl_format_code_t CSC_CODE = 0x4;
......
...@@ -82,7 +82,7 @@ class GraphData(ObjectBase): ...@@ -82,7 +82,7 @@ class GraphData(ObjectBase):
return g return g
def save_graphs(filename, g_list, labels=None): def save_graphs(filename, g_list, labels=None, formats=None):
r"""Save graphs and optionally their labels to file. r"""Save graphs and optionally their labels to file.
Besides saving to local files, DGL supports writing the graphs directly Besides saving to local files, DGL supports writing the graphs directly
...@@ -101,6 +101,12 @@ def save_graphs(filename, g_list, labels=None): ...@@ -101,6 +101,12 @@ def save_graphs(filename, g_list, labels=None):
The graphs to be saved. The graphs to be saved.
labels: dict[str, Tensor] labels: dict[str, Tensor]
labels should be dict of tensors, with str as keys labels should be dict of tensors, with str as keys
formats: str or list[str]
Save graph in specified formats. It could be any combination of
``coo``, ``csc`` and ``csr``. If not specified, save one format
only according to what format is available. If multiple formats
are available, selection priority from high to low is ``coo``,
``csc``, ``csr``.
Examples Examples
---------- ----------
...@@ -138,7 +144,7 @@ def save_graphs(filename, g_list, labels=None): ...@@ -138,7 +144,7 @@ def save_graphs(filename, g_list, labels=None):
if ( if (
type(g_sample) == DGLHeteroGraph type(g_sample) == DGLHeteroGraph
): # Doesn't support DGLHeteroGraph's derived class ): # Doesn't support DGLHeteroGraph's derived class
save_heterographs(filename, g_list, labels) save_heterographs(filename, g_list, labels, formats)
else: else:
raise DGLError( raise DGLError(
"Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs." "Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs."
......
...@@ -19,7 +19,7 @@ def tensor_dict_to_ndarray_dict(tensor_dict): ...@@ -19,7 +19,7 @@ def tensor_dict_to_ndarray_dict(tensor_dict):
return convert_to_strmap(ndarray_dict) return convert_to_strmap(ndarray_dict)
def save_heterographs(filename, g_list, labels): def save_heterographs(filename, g_list, labels, formats):
"""Save heterographs into file""" """Save heterographs into file"""
if labels is None: if labels is None:
labels = {} labels = {}
...@@ -29,11 +29,14 @@ def save_heterographs(filename, g_list, labels): ...@@ -29,11 +29,14 @@ def save_heterographs(filename, g_list, labels):
[type(g) == DGLHeteroGraph for g in g_list] [type(g) == DGLHeteroGraph for g in g_list]
), "Invalid DGLHeteroGraph in g_list argument" ), "Invalid DGLHeteroGraph in g_list argument"
gdata_list = [HeteroGraphData.create(g) for g in g_list] gdata_list = [HeteroGraphData.create(g) for g in g_list]
if formats is None:
formats = []
elif isinstance(formats, str):
formats = [formats]
_CAPI_SaveHeteroGraphData( _CAPI_SaveHeteroGraphData(
filename, gdata_list, tensor_dict_to_ndarray_dict(labels) filename, gdata_list, tensor_dict_to_ndarray_dict(labels), formats
) )
@register_object("heterograph_serialize.HeteroGraphData") @register_object("heterograph_serialize.HeteroGraphData")
class HeteroGraphData(ObjectBase): class HeteroGraphData(ObjectBase):
"""Object to hold the data to be stored for DGLHeteroGraph""" """Object to hold the data to be stored for DGLHeteroGraph"""
......
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file graph/serialize/streamwithcount.h * \file graph/serialize/dglstream.h
* \brief Graph serialization header * \brief Graph serialization header
*/ */
#ifndef DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_ #ifndef DGL_GRAPH_SERIALIZE_DGLSTREAM_H_
#define DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_ #define DGL_GRAPH_SERIALIZE_DGLSTREAM_H_
#include <dgl/aten/spmat.h>
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/type_traits.h> #include <dmlc/type_traits.h>
#include <memory> #include <memory>
namespace dgl {
namespace serialize {
/*! /*!
* \brief StreamWithCount counts the bytes that already written into the * \brief DGLStream counts the bytes that already written into the
* underlying stream. * underlying stream.
*/ */
class StreamWithCount : dmlc::Stream { class DGLStream : public dmlc::Stream {
public: public:
static StreamWithCount *Create(const char *uri, const char *const flag, /*! \brief create a new DGLStream instance */
bool allow_null = false) { static DGLStream *Create(const char *uri, const char *const flag,
return new StreamWithCount(uri, flag, allow_null); bool allow_null, dgl_format_code_t formats) {
return new DGLStream(uri, flag, allow_null, formats);
} }
size_t Read(void *ptr, size_t size) override { size_t Read(void *ptr, size_t size) override {
...@@ -37,11 +42,21 @@ class StreamWithCount : dmlc::Stream { ...@@ -37,11 +42,21 @@ class StreamWithCount : dmlc::Stream {
uint64_t Count() const { return count_; } uint64_t Count() const { return count_; }
uint64_t FormatsToSave() const { return formats_to_save_; }
private: private:
StreamWithCount(const char *uri, const char *const flag, bool allow_null) DGLStream(const char *uri, const char *const flag, bool allow_null,
: strm_(dmlc::Stream::Create(uri, flag, allow_null)) {} dgl_format_code_t formats)
: strm_(dmlc::Stream::Create(uri, flag, allow_null)), formats_to_save_(formats) {
}
// stream for serialization
std::unique_ptr<dmlc::Stream> strm_; std::unique_ptr<dmlc::Stream> strm_;
// size of already written to stream
uint64_t count_ = 0; uint64_t count_ = 0;
// formats to use when saving graph
const dgl_format_code_t formats_to_save_ = ANY_CODE;
}; };
} // namespace serialize
} // namespace dgl
#endif // DGL_GRAPH_SERIALIZE_STREAMWITHCOUNT_H_ #endif // DGL_GRAPH_SERIALIZE_DGLSTREAM_H_
...@@ -49,7 +49,7 @@ ...@@ -49,7 +49,7 @@
#include "../heterograph.h" #include "../heterograph.h"
#include "./graph_serialize.h" #include "./graph_serialize.h"
#include "./streamwithcount.h" #include "./dglstream.h"
#include "dmlc/memory_io.h" #include "dmlc/memory_io.h"
namespace dgl { namespace dgl {
...@@ -62,9 +62,10 @@ using dmlc::io::FileSystem; ...@@ -62,9 +62,10 @@ using dmlc::io::FileSystem;
using dmlc::io::URI; using dmlc::io::URI;
bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata, bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata,
const std::vector<NamedTensor> &nd_list) { const std::vector<NamedTensor> &nd_list,
auto fs = std::unique_ptr<StreamWithCount>( dgl_format_code_t formats) {
StreamWithCount::Create(filename.c_str(), "w", false)); auto fs = std::unique_ptr<DGLStream>(
DGLStream::Create(filename.c_str(), "w", false, formats));
CHECK(fs->IsValid()) << "File name " << filename << " is not a valid name"; CHECK(fs->IsValid()) << "File name " << filename << " is not a valid name";
// Write DGL MetaData // Write DGL MetaData
...@@ -221,12 +222,18 @@ DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData") ...@@ -221,12 +222,18 @@ DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData")
std::string filename = args[0]; std::string filename = args[0];
List<HeteroGraphData> hgdata = args[1]; List<HeteroGraphData> hgdata = args[1];
Map<std::string, Value> nd_map = args[2]; Map<std::string, Value> nd_map = args[2];
List<Value> formats = args[3];
std::vector<SparseFormat> formats_vec;
for (const auto& val : formats) {
formats_vec.push_back(ParseSparseFormat(val->data));
}
const auto formats_code = SparseFormatsToCode(formats_vec);
std::vector<NamedTensor> nd_list; std::vector<NamedTensor> nd_list;
for (auto kv : nd_map) { for (auto kv : nd_map) {
NDArray ndarray = static_cast<NDArray>(kv.second->data); NDArray ndarray = static_cast<NDArray>(kv.second->data);
nd_list.emplace_back(kv.first, ndarray); nd_list.emplace_back(kv.first, ndarray);
} }
*rv = dgl::serialize::SaveHeteroGraphs(filename, hgdata, nd_list); *rv = dgl::serialize::SaveHeteroGraphs(filename, hgdata, nd_list, formats_code);
}); });
DGL_REGISTER_GLOBAL( DGL_REGISTER_GLOBAL(
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./unit_graph.h" #include "./unit_graph.h"
#include "./serialize/dglstream.h"
namespace dgl { namespace dgl {
...@@ -1648,7 +1649,14 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1648,7 +1649,14 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
int64_t save_format_code, formats_code; int64_t save_format_code, formats_code;
CHECK(fs->Read(&save_format_code)) << "Invalid format"; CHECK(fs->Read(&save_format_code)) << "Invalid format";
CHECK(fs->Read(&formats_code)) << "Invalid format"; CHECK(fs->Read(&formats_code)) << "Invalid format";
auto save_format = static_cast<SparseFormat>(save_format_code); dgl_format_code_t save_formats = ANY_CODE;
if (save_format_code >> 32) {
save_formats =
static_cast<dgl_format_code_t>(0xffffffff & save_format_code);
} else {
save_formats =
SparseFormatsToCode({static_cast<SparseFormat>(save_format_code)});
}
if (formats_code >> 32) { if (formats_code >> 32) {
formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code); formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);
} else { } else {
...@@ -1672,19 +1680,17 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1672,19 +1680,17 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
} }
} }
switch (save_format) { if (save_formats & COO_CODE) {
case SparseFormat::kCOO:
fs->Read(&coo_); fs->Read(&coo_);
break; }
case SparseFormat::kCSR: if (save_formats & CSR_CODE) {
fs->Read(&out_csr_); fs->Read(&out_csr_);
break; }
case SparseFormat::kCSC: if (save_formats & CSC_CODE) {
fs->Read(&in_csr_); fs->Read(&in_csr_);
break; }
default: if (!coo_ && !out_csr_ && !in_csr_) {
LOG(FATAL) << "unsupported format code"; LOG(FATAL) << "unsupported format code";
break;
} }
if (!in_csr_) { if (!in_csr_) {
...@@ -1707,22 +1713,24 @@ void UnitGraph::Save(dmlc::Stream* fs) const { ...@@ -1707,22 +1713,24 @@ void UnitGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_UnitGraphMagic); fs->Write(kDGLSerialize_UnitGraphMagic);
// Didn't write UnitGraph::meta_graph_, since it's included in the underlying // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
// sparse matrix // sparse matrix
auto avail_fmt = SelectFormat(ALL_CODE); auto save_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)});
fs->Write(static_cast<int64_t>(avail_fmt)); auto fstream = dynamic_cast<dgl::serialize::DGLStream *>(fs);
if (fstream) {
auto formats = fstream->FormatsToSave();
save_formats = formats == ANY_CODE
? SparseFormatsToCode({SelectFormat(ALL_CODE)})
: formats;
}
fs->Write(static_cast<int64_t>(save_formats | 0x100000000));
fs->Write(static_cast<int64_t>(formats_ | 0x100000000)); fs->Write(static_cast<int64_t>(formats_ | 0x100000000));
switch (avail_fmt) { if (save_formats & COO_CODE) {
case SparseFormat::kCOO:
fs->Write(GetCOO()); fs->Write(GetCOO());
break; }
case SparseFormat::kCSR: if (save_formats & CSR_CODE) {
fs->Write(GetOutCSR()); fs->Write(GetOutCSR());
break; }
case SparseFormat::kCSC: if (save_formats & CSC_CODE) {
fs->Write(GetInCSR()); fs->Write(GetInCSR());
break;
default:
LOG(FATAL) << "unsupported format code";
break;
} }
} }
......
...@@ -364,19 +364,84 @@ def test_serialize_heterograph_s3(): ...@@ -364,19 +364,84 @@ def test_serialize_heterograph_s3():
assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3])) assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))
if __name__ == "__main__": @unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
pass @pytest.mark.parametrize("is_hetero", [True, False])
# test_graph_serialize_with_feature(True) @pytest.mark.parametrize(
# test_graph_serialize_with_feature(False) "formats",
# test_graph_serialize_without_feature(True) [
# test_graph_serialize_without_feature(False) "coo",
# test_graph_serialize_with_labels(True) "csr",
# test_graph_serialize_with_labels(False) "csc",
# test_serialize_tensors() ["coo", "csc"],
# test_serialize_empty_dict() ["coo", "csr"],
# test_load_old_files1() ["csc", "csr"],
test_load_old_files2() ["coo", "csr", "csc"],
# test_serialize_heterograph() ],
# test_serialize_heterograph_s3() )
# test_deserialize_old_heterograph_file() def test_graph_serialize_with_formats(is_hetero, formats):
# create_old_heterograph_files() num_graphs = 100
g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)]
# create a temporary file and immediately release it so DGL can open it.
f = tempfile.NamedTemporaryFile(delete=False)
path = f.name
f.close()
dgl.save_graphs(path, g_list, formats=formats)
idx_list = np.random.permutation(np.arange(num_graphs)).tolist()
loadg_list, _ = dgl.load_graphs(path, idx_list)
idx = idx_list[0]
load_g = loadg_list[0]
g_formats = load_g.formats()
# verify formats
if not isinstance(formats, list):
formats = [formats]
for fmt in formats:
assert fmt in g_formats["created"]
assert F.allclose(load_g.nodes(), g_list[idx].nodes())
load_edges = load_g.all_edges("uv", "eid")
g_edges = g_list[idx].all_edges("uv", "eid")
assert F.allclose(load_edges[0], g_edges[0])
assert F.allclose(load_edges[1], g_edges[1])
os.unlink(path)
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
def test_graph_serialize_with_restricted_formats():
g = dgl.rand_graph(100, 200)
g = g.formats(["coo"])
g_list = [g]
# create a temporary file and immediately release it so DGL can open it.
f = tempfile.NamedTemporaryFile(delete=False)
path = f.name
f.close()
expect_except = False
try:
dgl.save_graphs(path, g_list, formats=["csr"])
except:
expect_except = True
assert expect_except
os.unlink(path)
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
def test_deserialize_old_graph():
num_nodes = 100
num_edges = 200
path = os.path.join(os.path.dirname(__file__), "data/graph_0.9a220622.dgl")
g_list, _ = dgl.load_graphs(path)
g = g_list[0]
assert "coo" in g.formats()["created"]
assert "csr" in g.formats()["not created"]
assert "csc" in g.formats()["not created"]
assert num_nodes == g.num_nodes()
assert num_edges == g.num_edges()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment