Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
aa562f7e
Unverified
Commit
aa562f7e
authored
Sep 04, 2023
by
czkkkkkk
Committed by
GitHub
Sep 04, 2023
Browse files
[Graphbolt] Refactor shared memory utility. (#6198)
parent
df1ea757
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
350 additions
and
206 deletions
+350
-206
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+7
-9
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+22
-17
graphbolt/src/shared_memory_utils.cc
graphbolt/src/shared_memory_utils.cc
+168
-137
graphbolt/src/shared_memory_utils.h
graphbolt/src/shared_memory_utils.h
+109
-39
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+44
-4
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
aa562f7e
...
@@ -31,6 +31,8 @@ struct SamplerArgs<SamplerType::LABOR> {
...
@@ -31,6 +31,8 @@ struct SamplerArgs<SamplerType::LABOR> {
int64_t
num_nodes
;
int64_t
num_nodes
;
};
};
class
SharedMemoryHelper
;
/**
/**
* @brief A sampling oriented csc format graph.
* @brief A sampling oriented csc format graph.
*
*
...
@@ -243,18 +245,14 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -243,18 +245,14 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
PickFn
pick_fn
)
const
;
PickFn
pick_fn
)
const
;
/**
/**
* @brief Build a CSCSamplingGraph from shared memory tensors.
* @brief Build a CSCSamplingGraph from a shared memory helper. This function
*
* takes ownership of the shared memory objects in the helper.
* @param shared_memory_tensors A tuple of two shared memory objects holding
* tensor meta information and data respectively, and a vector of optional
* tensors on shared memory.
*
*
* @param shared_memory_helper The shared memory helper.
* @return A new CSCSamplingGraph on shared memory.
* @return A new CSCSamplingGraph on shared memory.
*/
*/
static
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
BuildGraphFromSharedMemoryTensors
(
static
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
BuildGraphFromSharedMemoryHelper
(
std
::
tuple
<
SharedMemoryHelper
&&
shared_memory_helper
);
SharedMemoryPtr
,
SharedMemoryPtr
,
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>>&&
shared_memory_tensors
);
/** @brief CSC format index pointer array. */
/** @brief CSC format index pointer array. */
torch
::
Tensor
indptr_
;
torch
::
Tensor
indptr_
;
...
...
graphbolt/src/csc_sampling_graph.cc
View file @
aa562f7e
...
@@ -439,33 +439,38 @@ CSCSamplingGraph::SampleNegativeEdgesUniform(
...
@@ -439,33 +439,38 @@ CSCSamplingGraph::SampleNegativeEdgesUniform(
}
}
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
CSCSamplingGraph
::
BuildGraphFromSharedMemoryTensors
(
CSCSamplingGraph
::
BuildGraphFromSharedMemoryHelper
(
std
::
tuple
<
SharedMemoryHelper
&&
helper
)
{
SharedMemoryPtr
,
SharedMemoryPtr
,
helper
.
InitializeRead
();
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>>&&
shared_memory_tensors
)
{
auto
indptr
=
helper
.
ReadTorchTensor
();
auto
&
optional_tensors
=
std
::
get
<
2
>
(
shared_memory_tensors
);
auto
indices
=
helper
.
ReadTorchTensor
();
auto
node_type_offset
=
helper
.
ReadTorchTensor
();
auto
type_per_edge
=
helper
.
ReadTorchTensor
();
auto
edge_attributes
=
helper
.
ReadTorchTensorDict
();
auto
graph
=
c10
::
make_intrusive
<
CSCSamplingGraph
>
(
auto
graph
=
c10
::
make_intrusive
<
CSCSamplingGraph
>
(
optional_tensors
[
0
].
value
(),
optional_tensors
[
1
].
value
()
,
indptr
.
value
(),
indices
.
value
(),
node_type_offset
,
type_per_edge
,
optional_tensors
[
2
],
optional_tensors
[
3
],
torch
::
nullopt
);
edge_attributes
);
graph
->
tensor_meta_shm_
=
std
::
move
(
std
::
get
<
0
>
(
shared_memory_tensors
));
std
::
tie
(
graph
->
tensor_meta_shm_
,
graph
->
tensor_data_shm_
)
=
graph
->
tensor_data_shm_
=
std
::
move
(
std
::
get
<
1
>
(
s
hared
_m
emory
_tensors
)
);
helper
.
ReleaseS
hared
M
emory
(
);
return
graph
;
return
graph
;
}
}
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
CSCSamplingGraph
::
CopyToSharedMemory
(
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
CSCSamplingGraph
::
CopyToSharedMemory
(
const
std
::
string
&
shared_memory_name
)
{
const
std
::
string
&
shared_memory_name
)
{
auto
optional_tensors
=
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>
{
SharedMemoryHelper
helper
(
shared_memory_name
,
SERIALIZED_METAINFO_SIZE_MAX
);
indptr_
,
indices_
,
node_type_offset_
,
type_per_edge_
};
helper
.
WriteTorchTensor
(
indptr_
);
auto
shared_memory_tensors
=
CopyTensorsToSharedMemory
(
helper
.
WriteTorchTensor
(
indices_
);
shared_memory_name
,
optional_tensors
,
SERIALIZED_METAINFO_SIZE_MAX
);
helper
.
WriteTorchTensor
(
node_type_offset_
);
return
BuildGraphFromSharedMemoryTensors
(
std
::
move
(
shared_memory_tensors
));
helper
.
WriteTorchTensor
(
type_per_edge_
);
helper
.
WriteTorchTensorDict
(
edge_attributes_
);
helper
.
Flush
();
return
BuildGraphFromSharedMemoryHelper
(
std
::
move
(
helper
));
}
}
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
CSCSamplingGraph
::
LoadFromSharedMemory
(
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
CSCSamplingGraph
::
LoadFromSharedMemory
(
const
std
::
string
&
shared_memory_name
)
{
const
std
::
string
&
shared_memory_name
)
{
auto
shared_memory_tensors
=
LoadTensorsFromSharedMemory
(
SharedMemoryHelper
helper
(
shared_memory_name
,
SERIALIZED_METAINFO_SIZE_MAX
);
shared_memory_name
,
SERIALIZED_METAINFO_SIZE_MAX
);
return
BuildGraphFromSharedMemoryHelper
(
std
::
move
(
helper
));
return
BuildGraphFromSharedMemoryTensors
(
std
::
move
(
shared_memory_tensors
));
}
}
int64_t
NumPick
(
int64_t
NumPick
(
...
...
graphbolt/src/shared_memory_utils.cc
View file @
aa562f7e
...
@@ -18,163 +18,194 @@
...
@@ -18,163 +18,194 @@
namespace
graphbolt
{
namespace
graphbolt
{
namespace
sampling
{
namespace
sampling
{
static
SharedMemoryPtr
CopyTorchArchiveToSharedMemory
(
static
std
::
string
GetSharedMemoryMetadataName
(
const
std
::
string
&
name
)
{
const
std
::
string
&
name
,
int64_t
size
,
return
name
+
"_meta"
;
torch
::
serialize
::
OutputArchive
&
archive
)
{
std
::
stringstream
serialized
;
archive
.
save_to
(
serialized
);
auto
serialized_str
=
serialized
.
str
();
auto
shm
=
std
::
make_unique
<
SharedMemory
>
(
name
);
auto
mem_buf
=
shm
->
Create
(
size
);
// Use the first 8 bytes to store the size of the serialized string.
static_cast
<
int64_t
*>
(
mem_buf
)[
0
]
=
serialized_str
.
size
();
memcpy
(
(
char
*
)
mem_buf
+
sizeof
(
int64_t
),
serialized_str
.
data
(),
serialized_str
.
size
());
return
shm
;
}
}
static
SharedMemoryPtr
LoadTorchArchiveFromSharedMemory
(
static
std
::
string
GetSharedMemoryDataName
(
const
std
::
string
&
name
)
{
const
std
::
string
&
name
,
int64_t
max_meta_size
,
return
name
+
"_data"
;
torch
::
serialize
::
InputArchive
&
archive
)
{
auto
shm
=
std
::
make_unique
<
SharedMemory
>
(
name
);
auto
mem_buf
=
shm
->
Open
(
max_meta_size
);
int64_t
meta_size
=
static_cast
<
int64_t
*>
(
mem_buf
)[
0
];
archive
.
load_from
(
static_cast
<
const
char
*>
(
mem_buf
)
+
sizeof
(
int64_t
),
meta_size
);
return
shm
;
}
}
static
SharedMemoryPtr
CopyTensorsDataToSharedMemory
(
// To avoid unaligned memory access, we round the size of the binary buffer to
const
std
::
string
&
name
,
// the next multiple of 8 bytes (i.e. the size is rounded up).
const
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>&
tensors
)
{
inline
static
int64_t
GetRoundedSize
(
int64_t
size
)
{
int64_t
memory_size
=
0
;
constexpr
int64_t
ALIGNED_SIZE
=
8
;
for
(
const
auto
&
optional_tensor
:
tensors
)
{
return
(
size
+
ALIGNED_SIZE
-
1
)
/
ALIGNED_SIZE
*
ALIGNED_SIZE
;
if
(
optional_tensor
.
has_value
())
{
}
auto
tensor
=
optional_tensor
.
value
();
memory_size
+=
tensor
.
numel
()
*
tensor
.
element_size
();
SharedMemoryHelper
::
SharedMemoryHelper
(
}
const
std
::
string
&
name
,
int64_t
max_metadata_size
)
}
:
name_
(
name
),
auto
shm
=
std
::
make_unique
<
SharedMemory
>
(
name
);
max_metadata_size_
(
max_metadata_size
),
auto
mem_buf
=
shm
->
Create
(
memory_size
);
metadata_shared_memory_
(
nullptr
),
for
(
auto
optional_tensor
:
tensors
)
{
data_shared_memory_
(
nullptr
),
if
(
optional_tensor
.
has_value
())
{
metadata_offset_
(
0
),
auto
tensor
=
optional_tensor
.
value
().
contiguous
();
data_offset_
(
0
)
{}
int64_t
size
=
tensor
.
numel
()
*
tensor
.
element_size
();
memcpy
(
mem_buf
,
tensor
.
data_ptr
(),
size
);
void
SharedMemoryHelper
::
InitializeRead
()
{
mem_buf
=
static_cast
<
char
*>
(
mem_buf
)
+
size
;
metadata_offset_
=
0
;
}
data_offset_
=
0
;
if
(
metadata_shared_memory_
==
nullptr
)
{
// Reader process opens the shared memory.
metadata_shared_memory_
=
std
::
make_unique
<
SharedMemory
>
(
GetSharedMemoryMetadataName
(
name_
));
metadata_shared_memory_
->
Open
(
max_metadata_size_
);
auto
archive
=
this
->
ReadTorchArchive
();
int64_t
data_size
=
read_from_archive
(
archive
,
"data_size"
).
toInt
();
data_shared_memory_
=
std
::
make_unique
<
SharedMemory
>
(
GetSharedMemoryDataName
(
name_
));
data_shared_memory_
->
Open
(
data_size
);
}
else
{
// Writer process already has the shared memory.
// Skip the first archive, which records the data size, before reading.
this
->
ReadTorchArchive
();
}
}
return
shm
;
}
}
/**
void
SharedMemoryHelper
::
WriteTorchArchive
(
* @brief Load tensors data from shared memory.
torch
::
serialize
::
OutputArchive
&&
archive
)
{
* @param name The name of shared memory.
metadata_to_write_
.
emplace_back
(
std
::
move
(
archive
));
* @param tensor_metas The meta info of tensors, including a flag indicating
}
* whether the optional tensor has value, tensor shape and dtype.
*
torch
::
serialize
::
InputArchive
SharedMemoryHelper
::
ReadTorchArchive
()
{
* @return A pair of shared memory holding the tensors.
auto
metadata_ptr
=
this
->
GetCurrentMetadataPtr
();
*/
int64_t
metadata_size
=
static_cast
<
int64_t
*>
(
metadata_ptr
)[
0
];
static
std
::
pair
<
SharedMemoryPtr
,
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>>
torch
::
serialize
::
InputArchive
archive
;
LoadTensorsDataFromSharedMemory
(
archive
.
load_from
(
const
std
::
string
&
name
,
static_cast
<
const
char
*>
(
metadata_ptr
)
+
sizeof
(
int64_t
),
metadata_size
);
const
std
::
vector
<
auto
rounded_size
=
GetRoundedSize
(
metadata_size
);
std
::
tuple
<
bool
,
std
::
vector
<
int64_t
>
,
torch
::
ScalarType
>>&
this
->
MoveMetadataPtr
(
sizeof
(
int64_t
)
+
rounded_size
);
tensor_metas
)
{
return
archive
;
auto
shm
=
std
::
make_unique
<
SharedMemory
>
(
name
);
}
int64_t
memory_size
=
0
;
for
(
const
auto
&
meta
:
tensor_metas
)
{
void
SharedMemoryHelper
::
WriteTorchTensor
(
if
(
std
::
get
<
0
>
(
meta
)
)
{
torch
::
optional
<
torch
::
Tensor
>
tensor
)
{
int64_t
size
=
std
::
accumulate
(
torch
::
serialize
::
OutputArchive
archive
;
std
::
get
<
1
>
(
meta
).
begin
(),
std
::
get
<
1
>
(
meta
).
end
(),
1
,
archive
.
write
(
"has_value"
,
tensor
.
has_value
());
std
::
multiplies
<
int64_t
>
())
;
if
(
tensor
.
has_value
())
{
memory_size
+=
size
*
torch
::
elementSize
(
std
::
get
<
2
>
(
meta
));
archive
.
write
(
"shape"
,
tensor
.
value
().
sizes
(
));
}
archive
.
write
(
"dtype"
,
tensor
.
value
().
scalar_type
());
}
}
auto
mem_buf
=
shm
->
Open
(
memory_size
);
this
->
WriteTorchArchive
(
std
::
move
(
archive
));
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>
optional_tensors
;
tensors_to_write_
.
push_back
(
tensor
);
for
(
const
auto
&
meta
:
tensor_metas
)
{
}
if
(
std
::
get
<
0
>
(
meta
))
{
auto
tensor
=
torch
::
optional
<
torch
::
Tensor
>
SharedMemoryHelper
::
ReadTorchTensor
()
{
torch
::
from_blob
(
mem_buf
,
std
::
get
<
1
>
(
meta
),
std
::
get
<
2
>
(
meta
));
auto
archive
=
this
->
ReadTorchArchive
();
optional_tensors
.
push_back
(
tensor
);
bool
has_value
=
read_from_archive
(
archive
,
"has_value"
).
toBool
();
int64_t
size
=
std
::
accumulate
(
if
(
has_value
)
{
std
::
get
<
1
>
(
meta
).
begin
(),
std
::
get
<
1
>
(
meta
).
end
(),
1
,
auto
shape
=
read_from_archive
(
archive
,
"shape"
).
toIntVector
();
std
::
multiplies
<
int64_t
>
());
auto
dtype
=
read_from_archive
(
archive
,
"dtype"
).
toScalarType
();
mem_buf
=
static_cast
<
char
*>
(
mem_buf
)
+
auto
data_ptr
=
this
->
GetCurrentDataPtr
();
size
*
torch
::
elementSize
(
std
::
get
<
2
>
(
meta
));
auto
tensor
=
torch
::
from_blob
(
data_ptr
,
shape
,
dtype
);
}
else
{
auto
rounded_size
=
GetRoundedSize
(
tensor
.
numel
()
*
tensor
.
element_size
());
optional_tensors
.
push_back
(
torch
::
nullopt
);
this
->
MoveDataPtr
(
rounded_size
);
}
return
tensor
;
}
else
{
return
torch
::
nullopt
;
}
}
return
std
::
make_pair
(
std
::
move
(
shm
),
std
::
move
(
optional_tensors
));
}
}
SharedMemoryTensors
CopyTensorsToSharedMemory
(
void
SharedMemoryHelper
::
WriteTorchTensorDict
(
const
std
::
string
&
name
,
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>
tensor_dict
)
{
const
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>&
tensors
,
int64_t
max_meta_memory_size
)
{
torch
::
serialize
::
OutputArchive
archive
;
torch
::
serialize
::
OutputArchive
archive
;
archive
.
write
(
"num_tensors"
,
static_cast
<
int64_t
>
(
tensors
.
size
()));
if
(
!
tensor_dict
.
has_value
())
{
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
archive
.
write
(
"has_value"
,
false
);
archive
.
write
(
this
->
WriteTorchArchive
(
std
::
move
(
archive
));
"tensor_"
+
std
::
to_string
(
i
)
+
"_has_value"
,
tensors
[
i
].
has_value
());
return
;
if
(
tensors
[
i
].
has_value
())
{
archive
.
write
(
"tensor_"
+
std
::
to_string
(
i
)
+
"_shape"
,
tensors
[
i
].
value
().
sizes
());
archive
.
write
(
"tensor_"
+
std
::
to_string
(
i
)
+
"_dtype"
,
tensors
[
i
].
value
().
scalar_type
());
}
}
}
auto
meta_shm
=
CopyTorchArchiveToSharedMemory
(
archive
.
write
(
"has_value"
,
true
);
name
+
"_meta"
,
max_meta_memory_size
,
archive
);
auto
dict_value
=
tensor_dict
.
value
();
auto
data_shm
=
CopyTensorsDataToSharedMemory
(
name
+
"_data"
,
tensors
);
archive
.
write
(
"num_tensors"
,
static_cast
<
int64_t
>
(
dict_value
.
size
()));
int
counter
=
0
;
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>
ret_tensors
;
for
(
auto
it
=
dict_value
.
begin
();
it
!=
dict_value
.
end
();
++
it
)
{
auto
mem_buf
=
data_shm
->
GetMemory
();
archive
.
write
(
std
::
string
(
"key_"
)
+
std
::
to_string
(
counter
),
it
->
key
());
for
(
auto
optional_tensor
:
tensors
)
{
counter
++
;
if
(
optional_tensor
.
has_value
())
{
}
auto
tensor
=
optional_tensor
.
value
();
this
->
WriteTorchArchive
(
std
::
move
(
archive
));
ret_tensors
.
push_back
(
for
(
auto
it
=
dict_value
.
begin
();
it
!=
dict_value
.
end
();
++
it
)
{
torch
::
from_blob
(
mem_buf
,
tensor
.
sizes
(),
tensor
.
dtype
()));
this
->
WriteTorchTensor
(
it
->
value
());
int64_t
size
=
tensor
.
numel
()
*
tensor
.
element_size
();
mem_buf
=
static_cast
<
char
*>
(
mem_buf
)
+
size
;
}
else
{
ret_tensors
.
push_back
(
torch
::
nullopt
);
}
}
}
return
std
::
make_tuple
(
std
::
move
(
meta_shm
),
std
::
move
(
data_shm
),
std
::
move
(
ret_tensors
));
}
}
SharedMemoryTensors
LoadTensorsFromSharedMemory
(
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>
const
std
::
string
&
name
,
int64_t
meta_memory_size
)
{
SharedMemoryHelper
::
ReadTorchTensorDict
(
)
{
torch
::
serialize
::
InputArchive
a
rchive
;
auto
archive
=
this
->
ReadTorchA
rchive
()
;
auto
meta_shm
=
LoadTorchArchiveFromSharedMemory
(
if
(
!
read_from_archive
(
archive
,
"has_value"
).
toBool
())
{
name
+
"_meta"
,
meta_memory_size
,
archive
)
;
return
torch
::
nullopt
;
std
::
vector
<
std
::
tuple
<
bool
,
std
::
vector
<
int64_t
>
,
torch
::
ScalarType
>>
metas
;
}
int64_t
num_tensors
=
read_from_archive
(
archive
,
"num_tensors"
).
toInt
();
int64_t
num_tensors
=
read_from_archive
(
archive
,
"num_tensors"
).
toInt
();
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>
tensor_dict
;
for
(
int64_t
i
=
0
;
i
<
num_tensors
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
num_tensors
;
++
i
)
{
bool
has_value
=
auto
key
=
read_from_archive
(
archive
,
"tensor_"
+
std
::
to_string
(
i
)
+
"_has_value"
)
read_from_archive
(
archive
,
std
::
string
(
"key_"
)
+
std
::
to_string
(
i
))
.
toBool
();
.
toStringRef
();
if
(
has_value
)
{
auto
tensor
=
this
->
ReadTorchTensor
();
auto
shape
=
tensor_dict
.
insert
(
key
,
tensor
.
value
());
read_from_archive
(
archive
,
"tensor_"
+
std
::
to_string
(
i
)
+
"_shape"
)
}
.
toIntVector
();
return
tensor_dict
;
auto
dtype
=
}
read_from_archive
(
archive
,
"tensor_"
+
std
::
to_string
(
i
)
+
"_dtype"
)
.
toScalarType
();
void
SharedMemoryHelper
::
WriteTorchArchiveInternal
(
metas
.
push_back
({
true
,
shape
,
dtype
});
torch
::
serialize
::
OutputArchive
&
archive
)
{
}
else
{
std
::
stringstream
serialized
;
metas
.
push_back
({
false
,
{},
torch
::
ScalarType
::
Undefined
});
archive
.
save_to
(
serialized
);
auto
serialized_str
=
serialized
.
str
();
auto
metadata_ptr
=
this
->
GetCurrentMetadataPtr
();
static_cast
<
int64_t
*>
(
metadata_ptr
)[
0
]
=
serialized_str
.
size
();
memcpy
(
static_cast
<
char
*>
(
metadata_ptr
)
+
sizeof
(
int64_t
),
serialized_str
.
data
(),
serialized_str
.
size
());
int64_t
rounded_size
=
GetRoundedSize
(
serialized_str
.
size
());
this
->
MoveMetadataPtr
(
sizeof
(
int64_t
)
+
rounded_size
);
}
/**
 * @brief Copy the bytes of one optional tensor into the data shared memory.
 *
 * The bytes are written at the current data offset, and the offset is then
 * advanced by the size returned from `GetRoundedSize`. When the optional holds
 * no tensor, nothing is written and the offset is left unchanged.
 *
 * @param tensor The optional tensor whose contents should be written.
 */
void SharedMemoryHelper::WriteTorchTensorInternal(
    torch::optional<torch::Tensor> tensor) {
  // Nothing to copy for an absent tensor.
  if (!tensor.has_value()) return;

  const auto& source = tensor.value();
  size_t num_bytes = source.numel() * source.element_size();
  auto destination = this->GetCurrentDataPtr();
  // Copy from a contiguous version of the tensor so that a single memcpy of
  // `num_bytes` covers all of its elements.
  auto dense = source.contiguous();
  memcpy(destination, dense.data_ptr(), num_bytes);
  this->MoveDataPtr(GetRoundedSize(num_bytes));
}
void
SharedMemoryHelper
::
Flush
()
{
// The first archive records the size of the tensor data.
torch
::
serialize
::
OutputArchive
archive
;
size_t
data_size
=
0
;
for
(
auto
tensor
:
tensors_to_write_
)
{
if
(
tensor
.
has_value
())
{
auto
tensor_size
=
tensor
.
value
().
numel
()
*
tensor
.
value
().
element_size
();
data_size
+=
GetRoundedSize
(
tensor_size
);
}
}
}
}
return
std
::
tuple_cat
(
archive
.
write
(
"data_size"
,
static_cast
<
int64_t
>
(
data_size
));
std
::
forward_as_tuple
(
std
::
move
(
meta_shm
)),
metadata_shared_memory_
=
LoadTensorsDataFromSharedMemory
(
name
+
"_data"
,
metas
));
std
::
make_unique
<
SharedMemory
>
(
GetSharedMemoryMetadataName
(
name_
));
metadata_shared_memory_
->
Create
(
max_metadata_size_
);
metadata_offset_
=
0
;
this
->
WriteTorchArchiveInternal
(
archive
);
for
(
auto
&
archive
:
metadata_to_write_
)
{
this
->
WriteTorchArchiveInternal
(
archive
);
}
data_shared_memory_
=
std
::
make_unique
<
SharedMemory
>
(
GetSharedMemoryDataName
(
name_
));
data_shared_memory_
->
Create
(
data_size
);
data_offset_
=
0
;
for
(
auto
tensor
:
tensors_to_write_
)
{
this
->
WriteTorchTensorInternal
(
tensor
);
}
metadata_to_write_
.
clear
();
tensors_to_write_
.
clear
();
}
std
::
pair
<
SharedMemoryPtr
,
SharedMemoryPtr
>
SharedMemoryHelper
::
ReleaseSharedMemory
()
{
return
std
::
make_pair
(
std
::
move
(
metadata_shared_memory_
),
std
::
move
(
data_shared_memory_
));
}
}
}
// namespace sampling
}
// namespace sampling
...
...
graphbolt/src/shared_memory_utils.h
View file @
aa562f7e
...
@@ -20,52 +20,122 @@ namespace graphbolt {
...
@@ -20,52 +20,122 @@ namespace graphbolt {
namespace
sampling
{
namespace
sampling
{
/**
/**
* @brief SharedMemoryTensors includes: (1) two shared memory objects holding
* @brief SharedMemoryHelper is a helper class to write/read data structures
* tensor meta information and data respectively; (2) a vector of optional
* to/from shared memory.
* tensors on shared memory.
*/
using
SharedMemoryTensors
=
std
::
tuple
<
SharedMemoryPtr
,
SharedMemoryPtr
,
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>>
;
/**
* @brief Copy torch tensors to shared memory.
*
*
* To simplify this interface, a regular tensor is also wrapped as an optional
* In order to write data structure to shared memory, we need to serialize the
* one.
* data structure to a binary buffer and then write the buffer to the shared
* memory. However, the size of the binary buffer is not known in advance. To
* solve this problem, we use two shared memory objects: one for storing the
* metadata and the other for storing the binary buffer. The metadata includes
* the meta information of data structures such as size and shape. The size of
* the metadata is decided by the user via `max_metadata_size`. The size of
* the binary buffer is decided by the size of the data structures.
*
*
* The function has two steps:
* To avoid repeated shared memory allocation, this helper class uses lazy data
* 1. Copy meta info to shared memory `shared_memory_name + "_meta"`. This is to
* structure writing. The data structures are written to the shared memory only
* make sure that other loading processes can get the meta info of tensors.
* when `Flush` is called. The data structures are written in the order of
* 2. Copy tensors to shared memory `shared_memory_name + "_data"`, which can be
* calling `WriteTorchArchive`, `WriteTorchTensor` and `WriteTorchTensorDict`,
* loaded by other processes with meta info.
* and also read in the same order.
*
*
* The order of tensors loaded from `LoadTensorsFromSharedMemory` will be
* The usage of this class as a writer is as follows:
* exactly the same as the tensors copied from `CopyTensorsToSharedMemory`.
* @code{.cpp}
* SharedMemoryHelper shm_helper("shm_name", 1024);
* shm_helper.WriteTorchArchive(std::move(archive));
* shm_helper.WriteTorchTensor(tensor);
* shm_helper.WriteTorchTensorDict(tensor_dict);
* shm_helper.Flush();
* // After `Flush`, the data structures are written to the shared memory.
* // Then the helper class can be used as a reader.
* shm_helper.InitializeRead();
* auto archive = shm_helper.ReadTorchArchive();
* auto tensor = shm_helper.ReadTorchTensor();
* auto tensor_dict = shm_helper.ReadTorchTensorDict();
* @endcode
*
*
* @param name The name of shared memory.
* The usage of this class as a reader is as follows:
* @param tensors The tensors to copy.
* @code{.cpp}
* @param max_meta_memory_size The maximum size of meta memory.
* SharedMemoryHelper shm_helper("shm_name", 1024);
*
* shm_helper.InitializeRead();
* @return A tuple of tensor meta shared memory, tensor data shared memory, and
* auto archive = shm_helper.ReadTorchArchive();
* shared optional tensors.
* auto tensor = shm_helper.ReadTorchTensor();
*/
* auto tensor_dict = shm_helper.ReadTorchTensorDict();
SharedMemoryTensors
CopyTensorsToSharedMemory
(
* @endcode
const
std
::
string
&
name
,
const
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>&
tensors
,
int64_t
max_meta_memory_size
);
/**
* @brief Load torch tensors from shared memory.
*
*
* @param name The name of shared memory.
* @param max_meta_memory_size The maximum size of meta memory.
*
*
* @return A tuple of tensor meta shared memory, tensor data shared memory,
* and shared tensors.
*/
*/
SharedMemoryTensors
LoadTensorsFromSharedMemory
(
class
SharedMemoryHelper
{
const
std
::
string
&
name
,
int64_t
max_meta_memory_size
);
public:
/**
* @brief Constructor of the shared memory helper.
* @param name The name of the shared memory.
* @param max_metadata_size The maximum size of metadata.
*/
SharedMemoryHelper
(
const
std
::
string
&
name
,
int64_t
max_metadata_size
);
/** @brief Initialize this helper class before reading. */
void
InitializeRead
();
void
WriteTorchArchive
(
torch
::
serialize
::
OutputArchive
&&
archive
);
torch
::
serialize
::
InputArchive
ReadTorchArchive
();
void
WriteTorchTensor
(
torch
::
optional
<
torch
::
Tensor
>
tensor
);
torch
::
optional
<
torch
::
Tensor
>
ReadTorchTensor
();
void
WriteTorchTensorDict
(
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>
tensor_dict
);
torch
::
optional
<
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>>
ReadTorchTensorDict
();
/** @brief Flush the data structures to the shared memory. */
void
Flush
();
/** @brief Release ownership of both shared memory objects and return them to the caller. */
std
::
pair
<
SharedMemoryPtr
,
SharedMemoryPtr
>
ReleaseSharedMemory
();
private:
/**
* @brief Write the metadata to the shared memory. This function is
* called by `Flush`.
*/
void
WriteTorchArchiveInternal
(
torch
::
serialize
::
OutputArchive
&
archive
);
/**
* @brief Write the tensor data to the shared memory. This function is
* called by `Flush`.
*/
void
WriteTorchTensorInternal
(
torch
::
optional
<
torch
::
Tensor
>
tensor
);
/**
 * @brief Pointer to the current position in the metadata shared memory, i.e.
 * its base address plus `metadata_offset_` bytes.
 */
inline void* GetCurrentMetadataPtr() const {
  char* base = static_cast<char*>(metadata_shared_memory_->GetMemory());
  return base + metadata_offset_;
}
/**
 * @brief Pointer to the current position in the data shared memory, i.e. its
 * base address plus `data_offset_` bytes.
 */
inline void* GetCurrentDataPtr() const {
  char* base = static_cast<char*>(data_shared_memory_->GetMemory());
  return base + data_offset_;
}
/**
 * @brief Advance the metadata offset by `offset` bytes.
 *
 * The new offset is checked against `max_metadata_size_` before it is stored,
 * so a failing check leaves the offset unchanged.
 *
 * @param offset Number of bytes to advance.
 */
inline void MoveMetadataPtr(int64_t offset) {
  const int64_t advanced = metadata_offset_ + offset;
  TORCH_CHECK(
      advanced <= max_metadata_size_,
      "The size of metadata exceeds the maximum size of shared memory.");
  metadata_offset_ = advanced;
}
/**
 * @brief Advance the data offset by `offset` bytes. No bound is checked here.
 *
 * @param offset Number of bytes to advance.
 */
inline void MoveDataPtr(int64_t offset) {
  data_offset_ = data_offset_ + offset;
}
std
::
string
name_
;
bool
is_creator_
;
int64_t
max_metadata_size_
;
// The shared memory objects for storing metadata and tensor data.
SharedMemoryPtr
metadata_shared_memory_
,
data_shared_memory_
;
// The read/write offsets of the metadata and tensor data.
int64_t
metadata_offset_
,
data_offset_
;
// The data structures to write to the shared memory. They are written to the
// shared memory only when `Flush` is called.
std
::
vector
<
torch
::
serialize
::
OutputArchive
>
metadata_to_write_
;
std
::
vector
<
torch
::
optional
<
torch
::
Tensor
>>
tensors_to_write_
;
};
}
// namespace sampling
}
// namespace sampling
}
// namespace graphbolt
}
// namespace graphbolt
...
...
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
aa562f7e
...
@@ -807,9 +807,17 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
...
@@ -807,9 +807,17 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
)
)
def
test_homo_graph_on_shared_memory
(
num_nodes
,
num_edges
):
@
pytest
.
mark
.
parametrize
(
"test_edge_attrs"
,
[
True
,
False
])
def
test_homo_graph_on_shared_memory
(
num_nodes
,
num_edges
,
test_edge_attrs
):
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
num_nodes
,
num_edges
)
csc_indptr
,
indices
=
gbt
.
random_homo_graph
(
num_nodes
,
num_edges
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
if
test_edge_attrs
:
edge_attributes
=
{
"A1"
:
torch
.
randn
(
num_edges
),
"A2"
:
torch
.
randn
(
num_edges
),
}
else
:
edge_attributes
=
None
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
edge_attributes
=
edge_attributes
)
shm_name
=
"test_homo_g"
shm_name
=
"test_homo_g"
graph1
=
graph
.
copy_to_shared_memory
(
shm_name
)
graph1
=
graph
.
copy_to_shared_memory
(
shm_name
)
...
@@ -834,6 +842,15 @@ def test_homo_graph_on_shared_memory(num_nodes, num_edges):
...
@@ -834,6 +842,15 @@ def test_homo_graph_on_shared_memory(num_nodes, num_edges):
)
)
check_tensors_on_the_same_shared_memory
(
graph1
.
indices
,
graph2
.
indices
)
check_tensors_on_the_same_shared_memory
(
graph1
.
indices
,
graph2
.
indices
)
if
test_edge_attrs
:
for
name
,
edge_attr
in
edge_attributes
.
items
():
assert
name
in
graph1
.
edge_attributes
assert
name
in
graph2
.
edge_attributes
assert
torch
.
equal
(
graph1
.
edge_attributes
[
name
],
edge_attr
)
check_tensors_on_the_same_shared_memory
(
graph1
.
edge_attributes
[
name
],
graph2
.
edge_attributes
[
name
]
)
assert
graph1
.
metadata
is
None
and
graph2
.
metadata
is
None
assert
graph1
.
metadata
is
None
and
graph2
.
metadata
is
None
assert
graph1
.
node_type_offset
is
None
and
graph2
.
node_type_offset
is
None
assert
graph1
.
node_type_offset
is
None
and
graph2
.
node_type_offset
is
None
assert
graph1
.
type_per_edge
is
None
and
graph2
.
type_per_edge
is
None
assert
graph1
.
type_per_edge
is
None
and
graph2
.
type_per_edge
is
None
...
@@ -847,8 +864,9 @@ def test_homo_graph_on_shared_memory(num_nodes, num_edges):
...
@@ -847,8 +864,9 @@ def test_homo_graph_on_shared_memory(num_nodes, num_edges):
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
)
)
@
pytest
.
mark
.
parametrize
(
"num_ntypes, num_etypes"
,
[(
1
,
1
),
(
3
,
5
),
(
100
,
1
)])
@
pytest
.
mark
.
parametrize
(
"num_ntypes, num_etypes"
,
[(
1
,
1
),
(
3
,
5
),
(
100
,
1
)])
@
pytest
.
mark
.
parametrize
(
"test_edge_attrs"
,
[
True
,
False
])
def
test_hetero_graph_on_shared_memory
(
def
test_hetero_graph_on_shared_memory
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
,
test_edge_attrs
):
):
(
(
csc_indptr
,
csc_indptr
,
...
@@ -857,8 +875,21 @@ def test_hetero_graph_on_shared_memory(
...
@@ -857,8 +875,21 @@ def test_hetero_graph_on_shared_memory(
type_per_edge
,
type_per_edge
,
metadata
,
metadata
,
)
=
gbt
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
)
=
gbt
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
if
test_edge_attrs
:
edge_attributes
=
{
"A1"
:
torch
.
randn
(
num_edges
),
"A2"
:
torch
.
randn
(
num_edges
),
}
else
:
edge_attributes
=
None
graph
=
gb
.
from_csc
(
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
None
,
metadata
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
edge_attributes
,
metadata
,
)
)
shm_name
=
"test_hetero_g"
shm_name
=
"test_hetero_g"
...
@@ -894,6 +925,15 @@ def test_hetero_graph_on_shared_memory(
...
@@ -894,6 +925,15 @@ def test_hetero_graph_on_shared_memory(
graph1
.
type_per_edge
,
graph2
.
type_per_edge
graph1
.
type_per_edge
,
graph2
.
type_per_edge
)
)
if
test_edge_attrs
:
for
name
,
edge_attr
in
edge_attributes
.
items
():
assert
name
in
graph1
.
edge_attributes
assert
name
in
graph2
.
edge_attributes
assert
torch
.
equal
(
graph1
.
edge_attributes
[
name
],
edge_attr
)
check_tensors_on_the_same_shared_memory
(
graph1
.
edge_attributes
[
name
],
graph2
.
edge_attributes
[
name
]
)
assert
metadata
.
node_type_to_id
==
graph1
.
metadata
.
node_type_to_id
assert
metadata
.
node_type_to_id
==
graph1
.
metadata
.
node_type_to_id
assert
metadata
.
edge_type_to_id
==
graph1
.
metadata
.
edge_type_to_id
assert
metadata
.
edge_type_to_id
==
graph1
.
metadata
.
edge_type_to_id
assert
metadata
.
node_type_to_id
==
graph2
.
metadata
.
node_type_to_id
assert
metadata
.
node_type_to_id
==
graph2
.
metadata
.
node_type_to_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment