Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8a852530
Unverified
Commit
8a852530
authored
Dec 06, 2023
by
Rhett Ying
Committed by
GitHub
Dec 06, 2023
Browse files
[GraphBolt] support node/edge_type_to_id in shared memory (#6693)
parent
a19d5f3b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
3 deletions
+33
-3
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+2
-2
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+31
-1
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
8a852530
...
@@ -378,9 +378,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -378,9 +378,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Maximum number of bytes used to serialize the metadata of the
* @brief Maximum number of bytes used to serialize the metadata of the
* member tensors, including tensor shape and dtype. The constant is estimated
* member tensors, including tensor shape and dtype. The constant is estimated
* by multiplying the number of tensors in this class and the maximum number
* by multiplying the number of tensors in this class and the maximum number
* of bytes used to serialize the metadata of a tensor (
4
* 8192 for now).
* of bytes used to serialize the metadata of a tensor (
10
* 8192 for now).
*/
*/
static
constexpr
int64_t
SERIALIZED_METAINFO_SIZE_MAX
=
32768
;
static
constexpr
int64_t
SERIALIZED_METAINFO_SIZE_MAX
=
10
*
81920
;
/**
/**
* @brief Shared memory used to hold the tensor metadata and data of this
* @brief Shared memory used to hold the tensor metadata and data of this
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
8a852530
...
@@ -19,6 +19,32 @@
...
@@ -19,6 +19,32 @@
#include "./random.h"
#include "./random.h"
#include "./shared_memory_helper.h"
#include "./shared_memory_helper.h"
namespace
{
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>
TensorizeDict
(
const
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
int64_t
>>&
dict
)
{
if
(
!
dict
.
has_value
())
{
return
torch
::
nullopt
;
}
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>
result
;
for
(
const
auto
&
pair
:
dict
.
value
())
{
result
.
insert
(
pair
.
key
(),
torch
::
tensor
(
pair
.
value
(),
torch
::
kInt64
));
}
return
result
;
}
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
int64_t
>>
DetensorizeDict
(
const
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>&
dict
)
{
if
(
!
dict
.
has_value
())
{
return
torch
::
nullopt
;
}
torch
::
Dict
<
std
::
string
,
int64_t
>
result
;
for
(
const
auto
&
pair
:
dict
.
value
())
{
result
.
insert
(
pair
.
key
(),
pair
.
value
().
item
<
int64_t
>
());
}
return
result
;
}
}
// namespace
namespace
graphbolt
{
namespace
graphbolt
{
namespace
sampling
{
namespace
sampling
{
...
@@ -556,10 +582,12 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
...
@@ -556,10 +582,12 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
auto
indices
=
helper
.
ReadTorchTensor
();
auto
indices
=
helper
.
ReadTorchTensor
();
auto
node_type_offset
=
helper
.
ReadTorchTensor
();
auto
node_type_offset
=
helper
.
ReadTorchTensor
();
auto
type_per_edge
=
helper
.
ReadTorchTensor
();
auto
type_per_edge
=
helper
.
ReadTorchTensor
();
auto
node_type_to_id
=
DetensorizeDict
(
helper
.
ReadTorchTensorDict
());
auto
edge_type_to_id
=
DetensorizeDict
(
helper
.
ReadTorchTensorDict
());
auto
edge_attributes
=
helper
.
ReadTorchTensorDict
();
auto
edge_attributes
=
helper
.
ReadTorchTensorDict
();
auto
graph
=
c10
::
make_intrusive
<
FusedCSCSamplingGraph
>
(
auto
graph
=
c10
::
make_intrusive
<
FusedCSCSamplingGraph
>
(
indptr
.
value
(),
indices
.
value
(),
node_type_offset
,
type_per_edge
,
indptr
.
value
(),
indices
.
value
(),
node_type_offset
,
type_per_edge
,
torch
::
nullopt
,
torch
::
nullopt
,
edge_attributes
);
node_type_to_id
,
edge_type_to_id
,
edge_attributes
);
auto
shared_memory
=
helper
.
ReleaseSharedMemory
();
auto
shared_memory
=
helper
.
ReleaseSharedMemory
();
graph
->
HoldSharedMemoryObject
(
graph
->
HoldSharedMemoryObject
(
std
::
move
(
shared_memory
.
first
),
std
::
move
(
shared_memory
.
second
));
std
::
move
(
shared_memory
.
first
),
std
::
move
(
shared_memory
.
second
));
...
@@ -574,6 +602,8 @@ FusedCSCSamplingGraph::CopyToSharedMemory(
...
@@ -574,6 +602,8 @@ FusedCSCSamplingGraph::CopyToSharedMemory(
helper
.
WriteTorchTensor
(
indices_
);
helper
.
WriteTorchTensor
(
indices_
);
helper
.
WriteTorchTensor
(
node_type_offset_
);
helper
.
WriteTorchTensor
(
node_type_offset_
);
helper
.
WriteTorchTensor
(
type_per_edge_
);
helper
.
WriteTorchTensor
(
type_per_edge_
);
helper
.
WriteTorchTensorDict
(
TensorizeDict
(
node_type_to_id_
));
helper
.
WriteTorchTensorDict
(
TensorizeDict
(
edge_type_to_id_
));
helper
.
WriteTorchTensorDict
(
edge_attributes_
);
helper
.
WriteTorchTensorDict
(
edge_attributes_
);
helper
.
Flush
();
helper
.
Flush
();
return
BuildGraphFromSharedMemoryHelper
(
std
::
move
(
helper
));
return
BuildGraphFromSharedMemoryHelper
(
std
::
move
(
helper
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment