Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
72d16f78
Unverified
Commit
72d16f78
authored
May 19, 2023
by
Rhett Ying
Committed by
GitHub
May 19, 2023
Browse files
[GraphBolt] optimize load/save logic (#5713)
parent
1e295c53
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
120 additions
and
182 deletions
+120
-182
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+4
-30
graphbolt/include/graphbolt/serialize.h
graphbolt/include/graphbolt/serialize.h
+42
-107
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+25
-45
graphbolt/src/serialize.cc
graphbolt/src/serialize.cc
+49
-0
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
72d16f78
...
@@ -96,6 +96,9 @@ struct HeteroInfo {
...
@@ -96,6 +96,9 @@ struct HeteroInfo {
*/
*/
class
CSCSamplingGraph
:
public
torch
::
CustomClassHolder
{
class
CSCSamplingGraph
:
public
torch
::
CustomClassHolder
{
public:
public:
/** @brief Default constructor. */
CSCSamplingGraph
()
=
default
;
/**
/**
* @brief Constructor for CSC with data.
* @brief Constructor for CSC with data.
* @param num_nodes The number of nodes in the graph.
* @param num_nodes The number of nodes in the graph.
...
@@ -203,33 +206,4 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -203,33 +206,4 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
}
// namespace sampling
}
// namespace sampling
}
// namespace graphbolt
}
// namespace graphbolt
/**
* @brief Overload stream operator to enable `torch::save()` and `torch.load()`
* for CSCSamplingGraph.
*/
namespace torch {

// NOTE: these operators are intentionally NOT declared `inline`. Their
// definitions live in a single .cc file; an `inline` declaration would require
// a definition in every translation unit that odr-uses them (e.g. any caller
// of `torch::save()` / `torch::load()` on a CSCSamplingGraph).

/**
 * @brief Overload input stream operator for CSCSamplingGraph deserialization.
 * @param archive Input stream for deserializing.
 * @param graph CSCSamplingGraph.
 *
 * @return archive
 */
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph);

/**
 * @brief Overload output stream operator for CSCSamplingGraph serialization.
 * @param archive Output stream for serializing.
 * @param graph CSCSamplingGraph.
 *
 * @return archive
 */
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph);

}  // namespace torch
#endif // GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#endif // GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
graphbolt/include/graphbolt/serialize.h
View file @
72d16f78
...
@@ -12,134 +12,69 @@
...
@@ -12,134 +12,69 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
namespace
graphbolt
{
#include "csc_sampling_graph.h"
namespace
utils
{
/**
/**
* @brief Utility function to write to archive.
* @brief Overload stream operator to enable `torch::save()` and `torch.load()`
* @param archive Output archive.
* for CSCSamplingGraph.
* @param key Key name used in saving.
* @param data Data that could be constructed as `torch::IValue`.
*/
*/
template
<
typename
DataT
>
namespace
torch
{
void
write_to_archive
(
torch
::
serialize
::
OutputArchive
&
archive
,
const
std
::
string
&
key
,
const
DataT
&
data
)
{
archive
.
write
(
key
,
data
);
}
/**
/**
* @brief Specialization utility function to save string vector.
* @brief Overload input stream operator for CSCSamplingGraph deserialization.
* @param archive Output archive.
* @param archive Input stream for deserializing.
* @param key Key name used in saving.
* @param graph CSCSamplingGraph.
* @param data Vector of string.
*
* @return archive
*/
*/
template
<
>
inline
serialize
::
InputArchive
&
operator
>>
(
void
write_to_archive
<
std
::
vector
<
std
::
string
>>
(
serialize
::
InputArchive
&
archive
,
torch
::
serialize
::
OutputArchive
&
archive
,
const
std
::
string
&
key
,
graphbolt
::
sampling
::
CSCSamplingGraph
&
graph
);
const
std
::
vector
<
std
::
string
>&
data
)
{
archive
.
write
(
key
+
"/size"
,
torch
::
tensor
(
static_cast
<
int64_t
>
(
data
.
size
())));
for
(
const
auto
index
:
c10
::
irange
(
data
.
size
()))
{
archive
.
write
(
key
+
"/"
+
std
::
to_string
(
index
),
data
[
index
]);
}
}
/**
/**
* @brief Utility function to read from archive.
* @brief Overload output stream operator for CSCSamplingGraph serialization.
* @param archive Input archive.
* @param archive Output stream for serializing.
* @param key Key name used in reading.
* @param graph CSCSamplingGraph.
* @param data Data that could be constructed as `torch::IValue`.
*
* @return archive
*/
*/
template
<
typename
DataT
=
torch
::
IValue
>
inline
serialize
::
OutputArchive
&
operator
<<
(
void
read_from_archive
(
serialize
::
OutputArchive
&
archive
,
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
const
graphbolt
::
sampling
::
CSCSamplingGraph
&
graph
);
DataT
&
data
)
{
archive
.
read
(
key
,
data
);
}
/**
}
// namespace torch
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `bool`.
*/
template
<
>
void
read_from_archive
<
bool
>
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
bool
&
data
)
{
torch
::
IValue
iv_data
;
archive
.
read
(
key
,
iv_data
);
data
=
iv_data
.
toBool
();
}
/**
namespace
graphbolt
{
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `int64_t`.
*/
template
<
>
void
read_from_archive
<
int64_t
>
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
int64_t
&
data
)
{
torch
::
IValue
iv_data
;
archive
.
read
(
key
,
iv_data
);
data
=
iv_data
.
toInt
();
}
/**
/**
* @brief
Specialization utility function to read from archiv
e.
* @brief
Load CSCSamplingGraph from fil
e.
* @param
archive Input archive
.
* @param
filename File name to read
.
*
@param key Key name used in reading.
*
* @
param data Data that is `std::string`
.
* @
return CSCSamplingGraph
.
*/
*/
template
<
>
c10
::
intrusive_ptr
<
sampling
::
CSCSamplingGraph
>
LoadCSCSamplingGraph
(
void
read_from_archive
<
std
::
string
>
(
const
std
::
string
&
filename
);
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
std
::
string
&
data
)
{
torch
::
IValue
iv_data
;
archive
.
read
(
key
,
iv_data
);
data
=
iv_data
.
toString
();
}
/**
/**
* @brief S
pecialization utility function to read from archiv
e.
* @brief S
ave CSCSamplingGraph to fil
e.
* @param
archive Input archi
ve.
* @param
graph CSCSamplingGraph to sa
ve.
* @param
key Key name used in reading
.
* @param
filename File name to save
.
*
@param data Data that is `torch::Tensor`.
*
*/
*/
template
<
>
void
SaveCSCSamplingGraph
(
void
read_from_archive
<
torch
::
Tensor
>
(
c10
::
intrusive_ptr
<
sampling
::
CSCSamplingGraph
>
graph
,
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
const
std
::
string
&
filename
);
torch
::
Tensor
&
data
)
{
torch
::
IValue
iv_data
;
archive
.
read
(
key
,
iv_data
);
data
=
iv_data
.
toTensor
();
}
/**
/**
* @brief Specialization utility function to read to string vector.
* @brief Read data from archive.
* @param archive Output archive.
* @param archive Input archive.
* @param key Key name used in saving.
* @param key Key name of data.
* @param data Vector of string.
*
* @return data.
*/
*/
template
<
>
torch
::
IValue
read_from_archive
(
void
read_from_archive
<
std
::
vector
<
std
::
string
>>
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
);
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
std
::
vector
<
std
::
string
>&
data
)
{
int64_t
size
=
0
;
read_from_archive
<
int64_t
>
(
archive
,
key
+
"/size"
,
size
);
data
.
resize
(
static_cast
<
size_t
>
(
size
));
std
::
string
element
;
for
(
int64_t
index
=
0
;
index
<
size
;
++
index
)
{
read_from_archive
<
std
::
string
>
(
archive
,
key
+
"/"
+
std
::
to_string
(
index
),
element
);
data
[
index
]
=
element
;
}
}
}
// namespace utils
}
// namespace graphbolt
}
// namespace graphbolt
#endif // GRAPHBOLT_SERIALIZE_H_
#endif // GRAPHBOLT_SERIALIZE_H_
graphbolt/src/csc_sampling_graph.cc
View file @
72d16f78
...
@@ -11,26 +11,27 @@ namespace graphbolt {
...
@@ -11,26 +11,27 @@ namespace graphbolt {
namespace
sampling
{
namespace
sampling
{
void
HeteroInfo
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
void
HeteroInfo
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
int64_t
magic_num
=
0x0
;
const
int64_t
magic_num
=
utils
::
read_from_archive
(
archive
,
"HeteroInfo/magic_num"
,
magic_num
);
read_from_archive
(
archive
,
"HeteroInfo/magic_num"
).
toInt
(
);
TORCH_CHECK
(
TORCH_CHECK
(
magic_num
==
kHeteroInfoSerializeMagic
,
magic_num
==
kHeteroInfoSerializeMagic
,
"Magic numbers mismatch when loading HeteroInfo."
);
"Magic numbers mismatch when loading HeteroInfo."
);
utils
::
read_from_archive
(
archive
,
"HeteroInfo/node_types"
,
node_types
);
node_types
=
read_from_archive
(
archive
,
"HeteroInfo/node_types"
)
utils
::
read_from_archive
(
archive
,
"HeteroInfo/edge_types"
,
edge_types
);
.
to
<
decltype
(
node_types
)
>
();
utils
::
read_from_archive
(
edge_types
=
read_from_archive
(
archive
,
"HeteroInfo/edge_types"
)
archive
,
"HeteroInfo/node_type_offset"
,
node_type_offset
);
.
to
<
decltype
(
edge_types
)
>
();
utils
::
read_from_archive
(
archive
,
"HeteroInfo/type_per_edge"
,
type_per_edge
);
node_type_offset
=
read_from_archive
(
archive
,
"HeteroInfo/node_type_offset"
).
toTensor
();
type_per_edge
=
read_from_archive
(
archive
,
"HeteroInfo/type_per_edge"
).
toTensor
();
}
}
void
HeteroInfo
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
void
HeteroInfo
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
utils
::
write_to_archive
(
archive
.
write
(
"HeteroInfo/magic_num"
,
kHeteroInfoSerializeMagic
);
archive
,
"HeteroInfo/magic_num"
,
kHeteroInfoSerializeMagic
);
archive
.
write
(
"HeteroInfo/node_types"
,
node_types
);
utils
::
write_to_archive
(
archive
,
"HeteroInfo/node_types"
,
node_types
);
archive
.
write
(
"HeteroInfo/edge_types"
,
edge_types
);
utils
::
write_to_archive
(
archive
,
"HeteroInfo/edge_types"
,
edge_types
);
archive
.
write
(
"HeteroInfo/node_type_offset"
,
node_type_offset
);
utils
::
write_to_archive
(
archive
.
write
(
"HeteroInfo/type_per_edge"
,
type_per_edge
);
archive
,
"HeteroInfo/node_type_offset"
,
node_type_offset
);
utils
::
write_to_archive
(
archive
,
"HeteroInfo/type_per_edge"
,
type_per_edge
);
}
}
CSCSamplingGraph
::
CSCSamplingGraph
(
CSCSamplingGraph
::
CSCSamplingGraph
(
...
@@ -73,17 +74,16 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
...
@@ -73,17 +74,16 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
}
}
void
CSCSamplingGraph
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
void
CSCSamplingGraph
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
int64_t
magic_num
=
0x0
;
const
int64_t
magic_num
=
utils
::
read_from_archive
(
archive
,
"CSCSamplingGraph/magic_num"
,
magic_num
);
read_from_archive
(
archive
,
"CSCSamplingGraph/magic_num"
).
toInt
(
);
TORCH_CHECK
(
TORCH_CHECK
(
magic_num
==
kCSCSamplingGraphSerializeMagic
,
magic_num
==
kCSCSamplingGraphSerializeMagic
,
"Magic numbers mismatch when loading CSCSamplingGraph."
);
"Magic numbers mismatch when loading CSCSamplingGraph."
);
utils
::
read_from_archive
(
archive
,
"CSCSamplingGraph/num_nodes"
,
num_nodes_
);
num_nodes_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/num_nodes"
).
toInt
();
utils
::
read_from_archive
(
archive
,
"CSCSamplingGraph/indptr"
,
indptr_
);
indptr_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/indptr"
).
toTensor
();
utils
::
read_from_archive
(
archive
,
"CSCSamplingGraph/indices"
,
indices_
);
indices_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/indices"
).
toTensor
();
bool
is_heterogeneous
=
false
;
const
bool
is_heterogeneous
=
utils
::
read_from_archive
(
read_from_archive
(
archive
,
"CSCSamplingGraph/is_hetero"
).
toBool
();
archive
,
"CSCSamplingGraph/is_hetero"
,
is_heterogeneous
);
if
(
is_heterogeneous
)
{
if
(
is_heterogeneous
)
{
hetero_info_
=
std
::
make_shared
<
HeteroInfo
>
();
hetero_info_
=
std
::
make_shared
<
HeteroInfo
>
();
hetero_info_
->
Load
(
archive
);
hetero_info_
->
Load
(
archive
);
...
@@ -91,13 +91,11 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
...
@@ -91,13 +91,11 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
}
}
void
CSCSamplingGraph
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
void
CSCSamplingGraph
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
archive
.
write
(
archive
.
write
(
"CSCSamplingGraph/magic_num"
,
kCSCSamplingGraphSerializeMagic
);
"CSCSamplingGraph/magic_num"
,
archive
.
write
(
"CSCSamplingGraph/num_nodes"
,
num_nodes_
);
torch
::
IValue
(
kCSCSamplingGraphSerializeMagic
));
archive
.
write
(
"CSCSamplingGraph/num_nodes"
,
torch
::
IValue
(
num_nodes_
));
archive
.
write
(
"CSCSamplingGraph/indptr"
,
indptr_
);
archive
.
write
(
"CSCSamplingGraph/indptr"
,
indptr_
);
archive
.
write
(
"CSCSamplingGraph/indices"
,
indices_
);
archive
.
write
(
"CSCSamplingGraph/indices"
,
indices_
);
archive
.
write
(
"CSCSamplingGraph/is_hetero"
,
torch
::
IValue
(
IsHeterogeneous
())
)
;
archive
.
write
(
"CSCSamplingGraph/is_hetero"
,
IsHeterogeneous
());
if
(
IsHeterogeneous
())
{
if
(
IsHeterogeneous
())
{
hetero_info_
->
Save
(
archive
);
hetero_info_
->
Save
(
archive
);
}
}
...
@@ -105,21 +103,3 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
...
@@ -105,21 +103,3 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
}
// namespace sampling
}
// namespace sampling
}
// namespace graphbolt
}
// namespace graphbolt
namespace torch {

/**
 * @brief Deserialize a CSCSamplingGraph from an input archive.
 * @param archive Input archive to read from.
 * @param graph Graph that is populated via `CSCSamplingGraph::Load()`.
 *
 * @return The same archive, to allow chaining.
 */
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph) {
  // All of the actual reading is delegated to the graph itself.
  graph.Load(archive);
  return archive;
}

/**
 * @brief Serialize a CSCSamplingGraph into an output archive.
 * @param archive Output archive to write to.
 * @param graph Graph that is written via `CSCSamplingGraph::Save()`.
 *
 * @return The same archive, to allow chaining.
 */
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph) {
  // All of the actual writing is delegated to the graph itself.
  graph.Save(archive);
  return archive;
}

}  // namespace torch
graphbolt/src/serialize.cc
0 → 100644
View file @
72d16f78
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/src/serialize.cc
* @brief Source file of serialize.
*/
#include <graphbolt/serialize.h>
namespace torch {

// NOTE(review): the matching declarations in graphbolt/serialize.h appear to
// be marked `inline` while the definitions live only in this file -- confirm
// that no other translation unit odr-uses these operators, or drop `inline`
// from the header.

/**
 * @brief Deserialize a CSCSamplingGraph from an input archive.
 * @param archive Input archive to read from.
 * @param graph Graph that is populated via `CSCSamplingGraph::Load()`.
 *
 * @return The same archive, to allow chaining.
 */
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph) {
  // All of the actual reading is delegated to the graph itself.
  graph.Load(archive);
  return archive;
}

/**
 * @brief Serialize a CSCSamplingGraph into an output archive.
 * @param archive Output archive to write to.
 * @param graph Graph that is written via `CSCSamplingGraph::Save()`.
 *
 * @return The same archive, to allow chaining.
 */
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph) {
  // All of the actual writing is delegated to the graph itself.
  graph.Save(archive);
  return archive;
}

}  // namespace torch
namespace
graphbolt
{
/**
 * @brief Load a CSCSamplingGraph from file.
 * @param filename File name to read.
 *
 * @return Newly created CSCSamplingGraph populated by `torch::load()`.
 */
c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph(
    const std::string& filename) {
  // Hold the pointer by value (not `auto&&`): a named local object is eligible
  // for copy elision / implicit move on return, whereas returning through a
  // reference forces a copy of the intrusive_ptr in C++17.
  auto graph = c10::make_intrusive<sampling::CSCSamplingGraph>();
  // `torch::load()` fills the default-constructed graph in place.
  torch::load(*graph, filename);
  return graph;
}
/**
 * @brief Save a CSCSamplingGraph to file.
 * @param graph CSCSamplingGraph to save. It is dereferenced unconditionally.
 * @param filename File name to save.
 */
void SaveCSCSamplingGraph(
    c10::intrusive_ptr<sampling::CSCSamplingGraph> graph,
    const std::string& filename) {
  // Bind the pointee first so the call below reads as "save this graph".
  const sampling::CSCSamplingGraph& graph_ref = *graph;
  torch::save(graph_ref, filename);
}
/**
 * @brief Read data from archive.
 * @param archive Input archive.
 * @param key Key name of data.
 *
 * @return The value stored under `key`, as a generic `torch::IValue`; callers
 * convert it to the concrete type they expect.
 */
torch::IValue read_from_archive(
    torch::serialize::InputArchive& archive, const std::string& key) {
  // `InputArchive::read()` fills an output parameter, so stage the value in a
  // local and hand it back by value.
  torch::IValue value;
  archive.read(key, value);
  return value;
}
}
// namespace graphbolt
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment