Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
77fd747f
Unverified
Commit
77fd747f
authored
May 18, 2023
by
Rhett Ying
Committed by
GitHub
May 18, 2023
Browse files
[GraphBolt] enable load/save for CSCSamplingGraph (#5702)
parent
46af76c3
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
288 additions
and
0 deletions
+288
-0
graphbolt/include/csc_sampling_graph.h
graphbolt/include/csc_sampling_graph.h
+69
-0
graphbolt/include/serialize.h
graphbolt/include/serialize.h
+145
-0
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+74
-0
No files found.
graphbolt/include/csc_sampling_graph.h
View file @
77fd747f
...
@@ -49,6 +49,9 @@ struct HeteroInfo {
...
@@ -49,6 +49,9 @@ struct HeteroInfo {
node_type_offset
(
node_type_offset
),
node_type_offset
(
node_type_offset
),
type_per_edge
(
type_per_edge
)
{}
type_per_edge
(
type_per_edge
)
{}
/** @brief Default constructor. */
HeteroInfo
()
=
default
;
/** @brief List of node types in the graph.*/
/** @brief List of node types in the graph.*/
StringList
node_types
;
StringList
node_types
;
...
@@ -66,6 +69,24 @@ struct HeteroInfo {
...
@@ -66,6 +69,24 @@ struct HeteroInfo {
* edge_types. The length of it is equal to the number of edges.
* edge_types. The length of it is equal to the number of edges.
*/
*/
torch
::
Tensor
type_per_edge
;
torch
::
Tensor
type_per_edge
;
/**
** @brief Magic number to indicate Hetero info version in serialize/
** deserialize stages.
**/
static
constexpr
int64_t
kHeteroInfoSerializeMagic
=
0xDD2E60F0F6B4A129
;
/**
** @brief Load hetero info from stream.
** @param archive Input stream for deserializing.
**/
void
Load
(
torch
::
serialize
::
InputArchive
&
archive
);
/**
** @brief Save hetero info to stream.
** @param archive Output stream for serializing.
**/
void
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
;
};
};
/**
/**
...
@@ -148,6 +169,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -148,6 +169,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
return
hetero_info_
->
type_per_edge
;
return
hetero_info_
->
type_per_edge
;
}
}
/**
** @brief Magic number to indicate graph version in serialize/deserialize
** stage.
**/
static
constexpr
int64_t
kCSCSamplingGraphSerializeMagic
=
0xDD2E60F0F6B4A128
;
/**
** @brief Load graph from stream.
** @param archive Input stream for deserializing.
**/
void
Load
(
torch
::
serialize
::
InputArchive
&
archive
);
/**
** @brief Save graph to stream.
** @param archive Output stream for serializing.
**/
void
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
;
private:
private:
/** @brief The number of nodes of the graph. */
/** @brief The number of nodes of the graph. */
int64_t
num_nodes_
=
0
;
int64_t
num_nodes_
=
0
;
...
@@ -161,3 +200,33 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -161,3 +200,33 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
}
// namespace sampling
}
// namespace sampling
}
// namespace graphbolt
}
// namespace graphbolt
/**
** @brief Overload stream operator to enable `torch::save()` and `torch.load()`
** for CSCSamplingGraph.
**/
namespace
torch
{
/**
** @brief Overload input stream operator for CSCSamplingGraph deserialization.
** @param archive Input stream for deserializing.
** @param graph CSCSamplingGraph.
**
** @return archive
**/
inline
serialize
::
InputArchive
&
operator
>>
(
serialize
::
InputArchive
&
archive
,
graphbolt
::
sampling
::
CSCSamplingGraph
&
graph
);
/**
* @brief Overload output stream operator for CSCSamplingGraph serialization.
* @param archive Output stream for serializing.
* @param graph CSCSamplingGraph.
*
* @return archive
*/
inline
serialize
::
OutputArchive
&
operator
<<
(
serialize
::
OutputArchive
&
archive
,
const
graphbolt
::
sampling
::
CSCSamplingGraph
&
graph
);
}
// namespace torch
graphbolt/include/serialize.h
0 → 100644
View file @
77fd747f
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/include/serialize.h
* @brief Utility functions for serialize and deserialize.
*/
#ifndef GRAPHBOLT_INCLUDE_SERIALIZE_H_
#define GRAPHBOLT_INCLUDE_SERIALIZE_H_
#include <torch/torch.h>
#include <string>
#include <vector>
namespace
graphbolt
{
namespace
utils
{
/**
* @brief Utility function to write to archive.
* @param archive Output archive.
* @param key Key name used in saving.
* @param data Data that could be constructed as `torch::IValue`.
**/
template
<
typename
DataT
>
void
write_to_archive
(
torch
::
serialize
::
OutputArchive
&
archive
,
const
std
::
string
&
key
,
const
DataT
&
data
)
{
archive
.
write
(
key
,
data
);
}
/**
* @brief Specialization utility function to save string vector.
* @param archive Output archive.
* @param key Key name used in saving.
* @param data Vector of string.
**/
template
<
>
void
write_to_archive
<
std
::
vector
<
std
::
string
>>
(
torch
::
serialize
::
OutputArchive
&
archive
,
const
std
::
string
&
key
,
const
std
::
vector
<
std
::
string
>&
data
)
{
archive
.
write
(
key
+
"/size"
,
torch
::
tensor
(
static_cast
<
int64_t
>
(
data
.
size
())));
for
(
const
auto
index
:
c10
::
irange
(
data
.
size
()))
{
archive
.
write
(
key
+
"/"
+
std
::
to_string
(
index
),
data
[
index
]);
}
}
/**
* @brief Utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that could be constructed as `torch::IValue`.
**/
template
<
typename
DataT
=
torch
::
IValue
>
void
read_from_archive
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
DataT
&
data
)
{
archive
.
read
(
key
,
data
);
}
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `bool`.
**/
template
<
>
void
read_from_archive
<
bool
>
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
bool
&
data
)
{
torch
::
IValue
iv_data
;
archive
.
read
(
key
,
iv_data
);
data
=
iv_data
.
toBool
();
}
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `int64_t`.
**/
template
<
>
void
read_from_archive
<
int64_t
>
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
int64_t
&
data
)
{
torch
::
IValue
iv_data
;
archive
.
read
(
key
,
iv_data
);
data
=
iv_data
.
toInt
();
}
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `std::string`.
**/
template
<
>
void
read_from_archive
<
std
::
string
>
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
std
::
string
&
data
)
{
torch
::
IValue
iv_data
;
archive
.
read
(
key
,
iv_data
);
data
=
iv_data
.
toString
();
}
/**
* @brief Specialization utility function to read from archive.
* @param archive Input archive.
* @param key Key name used in reading.
* @param data Data that is `torch::Tensor`.
**/
template
<
>
void
read_from_archive
<
torch
::
Tensor
>
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
torch
::
Tensor
&
data
)
{
torch
::
IValue
iv_data
;
archive
.
read
(
key
,
iv_data
);
data
=
iv_data
.
toTensor
();
}
/**
* @brief Specialization utility function to read to string vector.
* @param archive Output archive.
* @param key Key name used in saving.
* @param data Vector of string.
**/
template
<
>
void
read_from_archive
<
std
::
vector
<
std
::
string
>>
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
,
std
::
vector
<
std
::
string
>&
data
)
{
int64_t
size
=
0
;
read_from_archive
<
int64_t
>
(
archive
,
key
+
"/size"
,
size
);
data
.
resize
(
static_cast
<
size_t
>
(
size
));
std
::
string
element
;
for
(
int64_t
index
=
0
;
index
<
size
;
++
index
)
{
read_from_archive
<
std
::
string
>
(
archive
,
key
+
"/"
+
std
::
to_string
(
index
),
element
);
data
[
index
]
=
element
;
}
}
}
// namespace utils
}
// namespace graphbolt
#endif // GRAPHBOLT_INCLUDE_SERIALIZE_H_
graphbolt/src/csc_sampling_graph.cc
View file @
77fd747f
...
@@ -6,9 +6,34 @@
...
@@ -6,9 +6,34 @@
#include "csc_sampling_graph.h"
#include "csc_sampling_graph.h"
#include "serialize.h"
namespace
graphbolt
{
namespace
graphbolt
{
namespace
sampling
{
namespace
sampling
{
void
HeteroInfo
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
int64_t
magic_num
=
0x0
;
utils
::
read_from_archive
(
archive
,
"HeteroInfo/magic_num"
,
magic_num
);
TORCH_CHECK
(
magic_num
==
kHeteroInfoSerializeMagic
,
"Magic numbers mismatch when loading HeteroInfo."
);
utils
::
read_from_archive
(
archive
,
"HeteroInfo/node_types"
,
node_types
);
utils
::
read_from_archive
(
archive
,
"HeteroInfo/edge_types"
,
edge_types
);
utils
::
read_from_archive
(
archive
,
"HeteroInfo/node_type_offset"
,
node_type_offset
);
utils
::
read_from_archive
(
archive
,
"HeteroInfo/type_per_edge"
,
type_per_edge
);
}
void
HeteroInfo
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
utils
::
write_to_archive
(
archive
,
"HeteroInfo/magic_num"
,
kHeteroInfoSerializeMagic
);
utils
::
write_to_archive
(
archive
,
"HeteroInfo/node_types"
,
node_types
);
utils
::
write_to_archive
(
archive
,
"HeteroInfo/edge_types"
,
edge_types
);
utils
::
write_to_archive
(
archive
,
"HeteroInfo/node_type_offset"
,
node_type_offset
);
utils
::
write_to_archive
(
archive
,
"HeteroInfo/type_per_edge"
,
type_per_edge
);
}
CSCSamplingGraph
::
CSCSamplingGraph
(
CSCSamplingGraph
::
CSCSamplingGraph
(
int64_t
num_nodes
,
torch
::
Tensor
&
indptr
,
torch
::
Tensor
&
indices
,
int64_t
num_nodes
,
torch
::
Tensor
&
indptr
,
torch
::
Tensor
&
indices
,
const
std
::
shared_ptr
<
HeteroInfo
>&
hetero_info
)
const
std
::
shared_ptr
<
HeteroInfo
>&
hetero_info
)
...
@@ -48,5 +73,54 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
...
@@ -48,5 +73,54 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
num_nodes
,
indptr
,
indices
,
hetero_info
);
num_nodes
,
indptr
,
indices
,
hetero_info
);
}
}
void
CSCSamplingGraph
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
int64_t
magic_num
=
0x0
;
utils
::
read_from_archive
(
archive
,
"CSCSamplingGraph/magic_num"
,
magic_num
);
TORCH_CHECK
(
magic_num
==
kCSCSamplingGraphSerializeMagic
,
"Magic numbers mismatch when loading CSCSamplingGraph."
);
utils
::
read_from_archive
(
archive
,
"CSCSamplingGraph/num_nodes"
,
num_nodes_
);
utils
::
read_from_archive
(
archive
,
"CSCSamplingGraph/indptr"
,
indptr_
);
utils
::
read_from_archive
(
archive
,
"CSCSamplingGraph/indices"
,
indices_
);
bool
is_heterogeneous
=
false
;
utils
::
read_from_archive
(
archive
,
"CSCSamplingGraph/is_hetero"
,
is_heterogeneous
);
if
(
is_heterogeneous
)
{
hetero_info_
=
std
::
make_shared
<
HeteroInfo
>
();
hetero_info_
->
Load
(
archive
);
}
}
void
CSCSamplingGraph
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
archive
.
write
(
"CSCSamplingGraph/magic_num"
,
torch
::
IValue
(
kCSCSamplingGraphSerializeMagic
));
archive
.
write
(
"CSCSamplingGraph/num_nodes"
,
torch
::
IValue
(
num_nodes_
));
archive
.
write
(
"CSCSamplingGraph/indptr"
,
indptr_
);
archive
.
write
(
"CSCSamplingGraph/indices"
,
indices_
);
archive
.
write
(
"CSCSamplingGraph/is_hetero"
,
torch
::
IValue
(
IsHeterogeneous
()));
if
(
IsHeterogeneous
())
{
hetero_info_
->
Save
(
archive
);
}
}
}
// namespace sampling
}
// namespace sampling
}
// namespace graphbolt
}
// namespace graphbolt
namespace
torch
{
serialize
::
InputArchive
&
operator
>>
(
serialize
::
InputArchive
&
archive
,
graphbolt
::
sampling
::
CSCSamplingGraph
&
graph
)
{
graph
.
Load
(
archive
);
return
archive
;
}
serialize
::
OutputArchive
&
operator
<<
(
serialize
::
OutputArchive
&
archive
,
const
graphbolt
::
sampling
::
CSCSamplingGraph
&
graph
)
{
graph
.
Save
(
archive
);
return
archive
;
}
}
// namespace torch
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment