Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a5e5f11a
"dgl_sparse/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a23b490df468cf86c94c1d05ccb864beea3ec2a8"
Unverified
Commit
a5e5f11a
authored
Dec 17, 2023
by
Rhett Ying
Committed by
GitHub
Dec 17, 2023
Browse files
[GraphBolt] de-duplicate code for reading data from achive (#6761)
parent
e181ef15
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
95 deletions
+43
-95
graphbolt/include/graphbolt/serialize.h
graphbolt/include/graphbolt/serialize.h
+8
-3
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+28
-73
graphbolt/src/serialize.cc
graphbolt/src/serialize.cc
+0
-11
graphbolt/src/shared_memory_helper.cc
graphbolt/src/shared_memory_helper.cc
+7
-8
No files found.
graphbolt/include/graphbolt/serialize.h
View file @
a5e5f11a
...
@@ -57,14 +57,19 @@ inline serialize::OutputArchive& operator<<(
...
@@ -57,14 +57,19 @@ inline serialize::OutputArchive& operator<<(
namespace
graphbolt
{
namespace
graphbolt
{
/**
/**
* @brief Read data from archive.
* @brief Read data from archive
and format to specified type
.
* @param archive Input archive.
* @param archive Input archive.
* @param key Key name of data.
* @param key Key name of data.
*
*
* @return data.
* @return data.
*/
*/
torch
::
IValue
read_from_archive
(
template
<
typename
T
>
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
);
T
read_from_archive
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
)
{
torch
::
IValue
data
;
archive
.
read
(
key
,
data
);
return
data
.
to
<
T
>
();
}
}
// namespace graphbolt
}
// namespace graphbolt
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
a5e5f11a
...
@@ -109,91 +109,46 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
...
@@ -109,91 +109,46 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
void
FusedCSCSamplingGraph
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
void
FusedCSCSamplingGraph
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
const
int64_t
magic_num
=
const
int64_t
magic_num
=
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/magic_num"
)
.
toInt
()
;
read_from_archive
<
int64_t
>
(
archive
,
"FusedCSCSamplingGraph/magic_num"
);
TORCH_CHECK
(
TORCH_CHECK
(
magic_num
==
kCSCSamplingGraphSerializeMagic
,
magic_num
==
kCSCSamplingGraphSerializeMagic
,
"Magic numbers mismatch when loading FusedCSCSamplingGraph."
);
"Magic numbers mismatch when loading FusedCSCSamplingGraph."
);
indptr_
=
indptr_
=
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/indptr"
).
toTensor
();
read_from_archive
<
torch
::
Tensor
>
(
archive
,
"FusedCSCSamplingGraph/indptr"
);
indices_
=
indices_
=
read_from_archive
<
torch
::
Tensor
>
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/indices"
).
toTensor
();
archive
,
"FusedCSCSamplingGraph/indices"
);
if
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/has_node_type_offset"
)
if
(
read_from_archive
<
bool
>
(
.
toBool
())
{
archive
,
"FusedCSCSamplingGraph/has_node_type_offset"
))
{
node_type_offset_
=
node_type_offset_
=
read_from_archive
<
torch
::
Tensor
>
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/node_type_offset"
)
archive
,
"FusedCSCSamplingGraph/node_type_offset"
);
.
toTensor
();
}
}
if
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/has_type_per_edge"
)
if
(
read_from_archive
<
bool
>
(
.
toBool
())
{
archive
,
"FusedCSCSamplingGraph/has_type_per_edge"
))
{
type_per_edge_
=
type_per_edge_
=
read_from_archive
<
torch
::
Tensor
>
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/type_per_edge"
)
archive
,
"FusedCSCSamplingGraph/type_per_edge"
);
.
toTensor
();
}
}
if
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/has_node_type_to_id"
)
if
(
read_from_archive
<
bool
>
(
.
toBool
())
{
archive
,
"FusedCSCSamplingGraph/has_node_type_to_id"
))
{
torch
::
Dict
<
torch
::
IValue
,
torch
::
IValue
>
generic_dict
=
node_type_to_id_
=
read_from_archive
<
NodeTypeToIDMap
>
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/node_type_to_id"
)
archive
,
"FusedCSCSamplingGraph/node_type_to_id"
);
.
toGenericDict
();
NodeTypeToIDMap
node_type_to_id
;
for
(
const
auto
&
pair
:
generic_dict
)
{
std
::
string
key
=
pair
.
key
().
toStringRef
();
int64_t
value
=
pair
.
value
().
toInt
();
node_type_to_id
.
insert
(
std
::
move
(
key
),
value
);
}
node_type_to_id_
=
std
::
move
(
node_type_to_id
);
}
}
if
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/has_edge_type_to_id"
)
if
(
read_from_archive
<
bool
>
(
.
toBool
())
{
archive
,
"FusedCSCSamplingGraph/has_edge_type_to_id"
))
{
torch
::
Dict
<
torch
::
IValue
,
torch
::
IValue
>
generic_dict
=
edge_type_to_id_
=
read_from_archive
<
EdgeTypeToIDMap
>
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/edge_type_to_id"
)
archive
,
"FusedCSCSamplingGraph/edge_type_to_id"
);
.
toGenericDict
();
EdgeTypeToIDMap
edge_type_to_id
;
for
(
const
auto
&
pair
:
generic_dict
)
{
std
::
string
key
=
pair
.
key
().
toStringRef
();
int64_t
value
=
pair
.
value
().
toInt
();
edge_type_to_id
.
insert
(
std
::
move
(
key
),
value
);
}
edge_type_to_id_
=
std
::
move
(
edge_type_to_id
);
}
}
// Optional node attributes.
if
(
read_from_archive
<
bool
>
(
torch
::
IValue
has_node_attributes
;
archive
,
"FusedCSCSamplingGraph/has_node_attributes"
))
{
if
(
archive
.
try_read
(
node_attributes_
=
read_from_archive
<
NodeAttrMap
>
(
"FusedCSCSamplingGraph/has_node_attributes"
,
has_node_attributes
)
&&
archive
,
"FusedCSCSamplingGraph/node_attributes"
);
has_node_attributes
.
toBool
())
{
torch
::
Dict
<
torch
::
IValue
,
torch
::
IValue
>
generic_dict
=
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/node_attributes"
)
.
toGenericDict
();
NodeAttrMap
target_dict
;
for
(
const
auto
&
pair
:
generic_dict
)
{
std
::
string
key
=
pair
.
key
().
toStringRef
();
torch
::
Tensor
value
=
pair
.
value
().
toTensor
();
// Use move to avoid copy.
target_dict
.
insert
(
std
::
move
(
key
),
std
::
move
(
value
));
}
// Same as above.
node_attributes_
=
std
::
move
(
target_dict
);
}
}
if
(
read_from_archive
<
bool
>
(
// Optional edge attributes.
archive
,
"FusedCSCSamplingGraph/has_edge_attributes"
))
{
torch
::
IValue
has_edge_attributes
;
edge_attributes_
=
read_from_archive
<
EdgeAttrMap
>
(
if
(
archive
.
try_read
(
archive
,
"FusedCSCSamplingGraph/edge_attributes"
);
"FusedCSCSamplingGraph/has_edge_attributes"
,
has_edge_attributes
)
&&
has_edge_attributes
.
toBool
())
{
torch
::
Dict
<
torch
::
IValue
,
torch
::
IValue
>
generic_dict
=
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/edge_attributes"
)
.
toGenericDict
();
EdgeAttrMap
target_dict
;
for
(
const
auto
&
pair
:
generic_dict
)
{
std
::
string
key
=
pair
.
key
().
toStringRef
();
torch
::
Tensor
value
=
pair
.
value
().
toTensor
();
// Use move to avoid copy.
target_dict
.
insert
(
std
::
move
(
key
),
std
::
move
(
value
));
}
// Same as above.
edge_attributes_
=
std
::
move
(
target_dict
);
}
}
}
}
...
...
graphbolt/src/serialize.cc
View file @
a5e5f11a
...
@@ -24,14 +24,3 @@ serialize::OutputArchive& operator<<(
...
@@ -24,14 +24,3 @@ serialize::OutputArchive& operator<<(
}
}
}
// namespace torch
}
// namespace torch
namespace
graphbolt
{
torch
::
IValue
read_from_archive
(
torch
::
serialize
::
InputArchive
&
archive
,
const
std
::
string
&
key
)
{
torch
::
IValue
data
;
archive
.
read
(
key
,
data
);
return
data
;
}
}
// namespace graphbolt
graphbolt/src/shared_memory_helper.cc
View file @
a5e5f11a
...
@@ -88,10 +88,10 @@ void SharedMemoryHelper::WriteTorchTensor(
...
@@ -88,10 +88,10 @@ void SharedMemoryHelper::WriteTorchTensor(
torch
::
optional
<
torch
::
Tensor
>
SharedMemoryHelper
::
ReadTorchTensor
()
{
torch
::
optional
<
torch
::
Tensor
>
SharedMemoryHelper
::
ReadTorchTensor
()
{
auto
archive
=
this
->
ReadTorchArchive
();
auto
archive
=
this
->
ReadTorchArchive
();
bool
has_value
=
read_from_archive
(
archive
,
"has_value"
)
.
toBool
()
;
bool
has_value
=
read_from_archive
<
bool
>
(
archive
,
"has_value"
);
if
(
has_value
)
{
if
(
has_value
)
{
auto
shape
=
read_from_archive
(
archive
,
"shape"
)
.
toIntVector
()
;
auto
shape
=
read_from_archive
<
std
::
vector
<
int64_t
>>
(
archive
,
"shape"
);
auto
dtype
=
read_from_archive
(
archive
,
"dtype"
)
.
toScalarType
()
;
auto
dtype
=
read_from_archive
<
torch
::
ScalarType
>
(
archive
,
"dtype"
);
auto
data_ptr
=
this
->
GetCurrentDataPtr
();
auto
data_ptr
=
this
->
GetCurrentDataPtr
();
auto
tensor
=
torch
::
from_blob
(
data_ptr
,
shape
,
dtype
);
auto
tensor
=
torch
::
from_blob
(
data_ptr
,
shape
,
dtype
);
auto
rounded_size
=
GetRoundedSize
(
tensor
.
numel
()
*
tensor
.
element_size
());
auto
rounded_size
=
GetRoundedSize
(
tensor
.
numel
()
*
tensor
.
element_size
());
...
@@ -127,15 +127,14 @@ void SharedMemoryHelper::WriteTorchTensorDict(
...
@@ -127,15 +127,14 @@ void SharedMemoryHelper::WriteTorchTensorDict(
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>
SharedMemoryHelper
::
ReadTorchTensorDict
()
{
SharedMemoryHelper
::
ReadTorchTensorDict
()
{
auto
archive
=
this
->
ReadTorchArchive
();
auto
archive
=
this
->
ReadTorchArchive
();
if
(
!
read_from_archive
(
archive
,
"has_value"
)
.
toBool
()
)
{
if
(
!
read_from_archive
<
bool
>
(
archive
,
"has_value"
))
{
return
torch
::
nullopt
;
return
torch
::
nullopt
;
}
}
int64_t
num_tensors
=
read_from_archive
(
archive
,
"num_tensors"
)
.
toInt
()
;
int64_t
num_tensors
=
read_from_archive
<
int64_t
>
(
archive
,
"num_tensors"
);
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>
tensor_dict
;
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>
tensor_dict
;
for
(
int64_t
i
=
0
;
i
<
num_tensors
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
num_tensors
;
++
i
)
{
auto
key
=
auto
key
=
read_from_archive
<
std
::
string
>
(
read_from_archive
(
archive
,
std
::
string
(
"key_"
)
+
std
::
to_string
(
i
))
archive
,
std
::
string
(
"key_"
)
+
std
::
to_string
(
i
));
.
toStringRef
();
auto
tensor
=
this
->
ReadTorchTensor
();
auto
tensor
=
this
->
ReadTorchTensor
();
tensor_dict
.
insert
(
key
,
tensor
.
value
());
tensor_dict
.
insert
(
key
,
tensor
.
value
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment