"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "97b66318e5f3c5853c7ac1c76187e3f77f825cb3"
Unverified Commit aa562f7e authored by czkkkkkk, committed by GitHub

[Graphbolt] Refactor shared memory utility. (#6198)

parent df1ea757
......@@ -31,6 +31,8 @@ struct SamplerArgs<SamplerType::LABOR> {
int64_t num_nodes;
};
class SharedMemoryHelper;
/**
* @brief A sampling oriented csc format graph.
*
......@@ -243,18 +245,14 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
PickFn pick_fn) const;
/**
* @brief Build a CSCSamplingGraph from shared memory tensors.
*
* @param shared_memory_tensors A tuple of two shared memory objects holding
* tensor meta information and data respectively, and a vector of optional
* tensors on shared memory.
* @brief Build a CSCSamplingGraph from a shared memory helper. This function
* takes ownership of the shared memory objects in the helper.
*
* @param shared_memory_helper The shared memory helper.
* @return A new CSCSamplingGraph on shared memory.
*/
static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryTensors(
std::tuple<
SharedMemoryPtr, SharedMemoryPtr,
std::vector<torch::optional<torch::Tensor>>>&& shared_memory_tensors);
static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper(
SharedMemoryHelper&& shared_memory_helper);
/** @brief CSC format index pointer array. */
torch::Tensor indptr_;
......
......@@ -439,33 +439,38 @@ CSCSamplingGraph::SampleNegativeEdgesUniform(
}
c10::intrusive_ptr<CSCSamplingGraph>
CSCSamplingGraph::BuildGraphFromSharedMemoryTensors(
std::tuple<
SharedMemoryPtr, SharedMemoryPtr,
std::vector<torch::optional<torch::Tensor>>>&& shared_memory_tensors) {
auto& optional_tensors = std::get<2>(shared_memory_tensors);
CSCSamplingGraph::BuildGraphFromSharedMemoryHelper(
SharedMemoryHelper&& helper) {
helper.InitializeRead();
auto indptr = helper.ReadTorchTensor();
auto indices = helper.ReadTorchTensor();
auto node_type_offset = helper.ReadTorchTensor();
auto type_per_edge = helper.ReadTorchTensor();
auto edge_attributes = helper.ReadTorchTensorDict();
auto graph = c10::make_intrusive<CSCSamplingGraph>(
optional_tensors[0].value(), optional_tensors[1].value(),
optional_tensors[2], optional_tensors[3], torch::nullopt);
graph->tensor_meta_shm_ = std::move(std::get<0>(shared_memory_tensors));
graph->tensor_data_shm_ = std::move(std::get<1>(shared_memory_tensors));
indptr.value(), indices.value(), node_type_offset, type_per_edge,
edge_attributes);
std::tie(graph->tensor_meta_shm_, graph->tensor_data_shm_) =
helper.ReleaseSharedMemory();
return graph;
}
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::CopyToSharedMemory(
const std::string& shared_memory_name) {
auto optional_tensors = std::vector<torch::optional<torch::Tensor>>{
indptr_, indices_, node_type_offset_, type_per_edge_};
auto shared_memory_tensors = CopyTensorsToSharedMemory(
shared_memory_name, optional_tensors, SERIALIZED_METAINFO_SIZE_MAX);
return BuildGraphFromSharedMemoryTensors(std::move(shared_memory_tensors));
SharedMemoryHelper helper(shared_memory_name, SERIALIZED_METAINFO_SIZE_MAX);
helper.WriteTorchTensor(indptr_);
helper.WriteTorchTensor(indices_);
helper.WriteTorchTensor(node_type_offset_);
helper.WriteTorchTensor(type_per_edge_);
helper.WriteTorchTensorDict(edge_attributes_);
helper.Flush();
return BuildGraphFromSharedMemoryHelper(std::move(helper));
}
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
const std::string& shared_memory_name) {
auto shared_memory_tensors = LoadTensorsFromSharedMemory(
shared_memory_name, SERIALIZED_METAINFO_SIZE_MAX);
return BuildGraphFromSharedMemoryTensors(std::move(shared_memory_tensors));
SharedMemoryHelper helper(shared_memory_name, SERIALIZED_METAINFO_SIZE_MAX);
return BuildGraphFromSharedMemoryHelper(std::move(helper));
}
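For orientation, here is a minimal sketch of the round trip that these two methods implement with the helper (the name "example_graph" and the tensor `indptr` are illustrative, not part of this change):
// Writer process: queue tensors, then Flush() allocates the shared memory and writes them.
SharedMemoryHelper writer("example_graph", SERIALIZED_METAINFO_SIZE_MAX);
writer.WriteTorchTensor(indptr);
writer.WriteTorchTensor(torch::nullopt);  // optional tensors may be absent
writer.Flush();
writer.InitializeRead();  // the writer may re-read its own buffers, as CopyToSharedMemory does
// Reader process: open the same name and read back in the same order.
SharedMemoryHelper reader("example_graph", SERIALIZED_METAINFO_SIZE_MAX);
reader.InitializeRead();
auto indptr_view = reader.ReadTorchTensor();  // zero-copy view into shared memory
auto absent = reader.ReadTorchTensor();       // torch::nullopt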
int64_t NumPick(
......
......@@ -18,163 +18,194 @@
namespace graphbolt {
namespace sampling {
static SharedMemoryPtr CopyTorchArchiveToSharedMemory(
const std::string& name, int64_t size,
torch::serialize::OutputArchive& archive) {
std::stringstream serialized;
archive.save_to(serialized);
auto serialized_str = serialized.str();
auto shm = std::make_unique<SharedMemory>(name);
auto mem_buf = shm->Create(size);
// Use the first 8 bytes to store the size of the serialized string.
static_cast<int64_t*>(mem_buf)[0] = serialized_str.size();
memcpy(
(char*)mem_buf + sizeof(int64_t), serialized_str.data(),
serialized_str.size());
return shm;
static std::string GetSharedMemoryMetadataName(const std::string& name) {
return name + "_meta";
}
static SharedMemoryPtr LoadTorchArchiveFromSharedMemory(
const std::string& name, int64_t max_meta_size,
torch::serialize::InputArchive& archive) {
auto shm = std::make_unique<SharedMemory>(name);
auto mem_buf = shm->Open(max_meta_size);
int64_t meta_size = static_cast<int64_t*>(mem_buf)[0];
archive.load_from(
static_cast<const char*>(mem_buf) + sizeof(int64_t), meta_size);
return shm;
static std::string GetSharedMemoryDataName(const std::string& name) {
return name + "_data";
}
static SharedMemoryPtr CopyTensorsDataToSharedMemory(
const std::string& name,
const std::vector<torch::optional<torch::Tensor>>& tensors) {
int64_t memory_size = 0;
for (const auto& optional_tensor : tensors) {
if (optional_tensor.has_value()) {
auto tensor = optional_tensor.value();
memory_size += tensor.numel() * tensor.element_size();
}
}
auto shm = std::make_unique<SharedMemory>(name);
auto mem_buf = shm->Create(memory_size);
for (auto optional_tensor : tensors) {
if (optional_tensor.has_value()) {
auto tensor = optional_tensor.value().contiguous();
int64_t size = tensor.numel() * tensor.element_size();
memcpy(mem_buf, tensor.data_ptr(), size);
mem_buf = static_cast<char*>(mem_buf) + size;
}
// To avoid unaligned memory access, we round the size of the binary buffer up
// to the next multiple of 8 bytes.
inline static int64_t GetRoundedSize(int64_t size) {
constexpr int64_t ALIGNED_SIZE = 8;
return (size + ALIGNED_SIZE - 1) / ALIGNED_SIZE * ALIGNED_SIZE;
}
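A quick sanity check of the rounding arithmetic (values are illustrative):
// (size + 8 - 1) / 8 * 8 rounds up to the next multiple of 8:
//   GetRoundedSize(0)  == 0
//   GetRoundedSize(13) == 16
//   GetRoundedSize(16) == 16
static_assert((13 + 8 - 1) / 8 * 8 == 16, "13 bytes round up to 16");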
SharedMemoryHelper::SharedMemoryHelper(
const std::string& name, int64_t max_metadata_size)
: name_(name),
max_metadata_size_(max_metadata_size),
metadata_shared_memory_(nullptr),
data_shared_memory_(nullptr),
metadata_offset_(0),
data_offset_(0) {}
void SharedMemoryHelper::InitializeRead() {
metadata_offset_ = 0;
data_offset_ = 0;
if (metadata_shared_memory_ == nullptr) {
// Reader process opens the shared memory.
metadata_shared_memory_ =
std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));
metadata_shared_memory_->Open(max_metadata_size_);
auto archive = this->ReadTorchArchive();
int64_t data_size = read_from_archive(archive, "data_size").toInt();
data_shared_memory_ =
std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));
data_shared_memory_->Open(data_size);
} else {
// Writer process already has the shared memory.
// Skip the first archive, which records the data size, before reading.
this->ReadTorchArchive();
}
return shm;
}
/**
* @brief Load tensors data from shared memory.
* @param name The name of shared memory.
* @param tensor_metas The meta info of tensors, including a flag indicating
* whether the optional tensor has value, tensor shape and dtype.
*
* @return A pair of the shared memory object and the optional tensors loaded from it.
*/
static std::pair<SharedMemoryPtr, std::vector<torch::optional<torch::Tensor>>>
LoadTensorsDataFromSharedMemory(
const std::string& name,
const std::vector<
std::tuple<bool, std::vector<int64_t>, torch::ScalarType>>&
tensor_metas) {
auto shm = std::make_unique<SharedMemory>(name);
int64_t memory_size = 0;
for (const auto& meta : tensor_metas) {
if (std::get<0>(meta)) {
int64_t size = std::accumulate(
std::get<1>(meta).begin(), std::get<1>(meta).end(), 1,
std::multiplies<int64_t>());
memory_size += size * torch::elementSize(std::get<2>(meta));
}
void SharedMemoryHelper::WriteTorchArchive(
torch::serialize::OutputArchive&& archive) {
metadata_to_write_.emplace_back(std::move(archive));
}
torch::serialize::InputArchive SharedMemoryHelper::ReadTorchArchive() {
auto metadata_ptr = this->GetCurrentMetadataPtr();
int64_t metadata_size = static_cast<int64_t*>(metadata_ptr)[0];
torch::serialize::InputArchive archive;
archive.load_from(
static_cast<const char*>(metadata_ptr) + sizeof(int64_t), metadata_size);
auto rounded_size = GetRoundedSize(metadata_size);
this->MoveMetadataPtr(sizeof(int64_t) + rounded_size);
return archive;
}
void SharedMemoryHelper::WriteTorchTensor(
torch::optional<torch::Tensor> tensor) {
torch::serialize::OutputArchive archive;
archive.write("has_value", tensor.has_value());
if (tensor.has_value()) {
archive.write("shape", tensor.value().sizes());
archive.write("dtype", tensor.value().scalar_type());
}
auto mem_buf = shm->Open(memory_size);
std::vector<torch::optional<torch::Tensor>> optional_tensors;
for (const auto& meta : tensor_metas) {
if (std::get<0>(meta)) {
auto tensor =
torch::from_blob(mem_buf, std::get<1>(meta), std::get<2>(meta));
optional_tensors.push_back(tensor);
int64_t size = std::accumulate(
std::get<1>(meta).begin(), std::get<1>(meta).end(), 1,
std::multiplies<int64_t>());
mem_buf = static_cast<char*>(mem_buf) +
size * torch::elementSize(std::get<2>(meta));
} else {
optional_tensors.push_back(torch::nullopt);
}
this->WriteTorchArchive(std::move(archive));
tensors_to_write_.push_back(tensor);
}
torch::optional<torch::Tensor> SharedMemoryHelper::ReadTorchTensor() {
auto archive = this->ReadTorchArchive();
bool has_value = read_from_archive(archive, "has_value").toBool();
if (has_value) {
auto shape = read_from_archive(archive, "shape").toIntVector();
auto dtype = read_from_archive(archive, "dtype").toScalarType();
auto data_ptr = this->GetCurrentDataPtr();
auto tensor = torch::from_blob(data_ptr, shape, dtype);
auto rounded_size = GetRoundedSize(tensor.numel() * tensor.element_size());
this->MoveDataPtr(rounded_size);
return tensor;
} else {
return torch::nullopt;
}
return std::make_pair(std::move(shm), std::move(optional_tensors));
}
SharedMemoryTensors CopyTensorsToSharedMemory(
const std::string& name,
const std::vector<torch::optional<torch::Tensor>>& tensors,
int64_t max_meta_memory_size) {
void SharedMemoryHelper::WriteTorchTensorDict(
torch::optional<torch::Dict<std::string, torch::Tensor>> tensor_dict) {
torch::serialize::OutputArchive archive;
archive.write("num_tensors", static_cast<int64_t>(tensors.size()));
for (size_t i = 0; i < tensors.size(); ++i) {
archive.write(
"tensor_" + std::to_string(i) + "_has_value", tensors[i].has_value());
if (tensors[i].has_value()) {
archive.write(
"tensor_" + std::to_string(i) + "_shape", tensors[i].value().sizes());
archive.write(
"tensor_" + std::to_string(i) + "_dtype",
tensors[i].value().scalar_type());
}
if (!tensor_dict.has_value()) {
archive.write("has_value", false);
this->WriteTorchArchive(std::move(archive));
return;
}
auto meta_shm = CopyTorchArchiveToSharedMemory(
name + "_meta", max_meta_memory_size, archive);
auto data_shm = CopyTensorsDataToSharedMemory(name + "_data", tensors);
std::vector<torch::optional<torch::Tensor>> ret_tensors;
auto mem_buf = data_shm->GetMemory();
for (auto optional_tensor : tensors) {
if (optional_tensor.has_value()) {
auto tensor = optional_tensor.value();
ret_tensors.push_back(
torch::from_blob(mem_buf, tensor.sizes(), tensor.dtype()));
int64_t size = tensor.numel() * tensor.element_size();
mem_buf = static_cast<char*>(mem_buf) + size;
} else {
ret_tensors.push_back(torch::nullopt);
}
archive.write("has_value", true);
auto dict_value = tensor_dict.value();
archive.write("num_tensors", static_cast<int64_t>(dict_value.size()));
int counter = 0;
for (auto it = dict_value.begin(); it != dict_value.end(); ++it) {
archive.write(std::string("key_") + std::to_string(counter), it->key());
counter++;
}
this->WriteTorchArchive(std::move(archive));
for (auto it = dict_value.begin(); it != dict_value.end(); ++it) {
this->WriteTorchTensor(it->value());
}
return std::make_tuple(
std::move(meta_shm), std::move(data_shm), std::move(ret_tensors));
}
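For reference, the serialization order this produces for a two-entry dict (keys are illustrative):
// WriteTorchTensorDict({"A1": t1, "A2": t2}) queues:
//   1. one archive: { has_value = true, num_tensors = 2, key_0 = "A1", key_1 = "A2" }
//   2. WriteTorchTensor(t1) and WriteTorchTensor(t2), in the dict's iteration order
// ReadTorchTensorDict consumes the archive and then the tensors in the same order.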
SharedMemoryTensors LoadTensorsFromSharedMemory(
const std::string& name, int64_t meta_memory_size) {
torch::serialize::InputArchive archive;
auto meta_shm = LoadTorchArchiveFromSharedMemory(
name + "_meta", meta_memory_size, archive);
std::vector<std::tuple<bool, std::vector<int64_t>, torch::ScalarType>> metas;
torch::optional<torch::Dict<std::string, torch::Tensor>>
SharedMemoryHelper::ReadTorchTensorDict() {
auto archive = this->ReadTorchArchive();
if (!read_from_archive(archive, "has_value").toBool()) {
return torch::nullopt;
}
int64_t num_tensors = read_from_archive(archive, "num_tensors").toInt();
torch::Dict<std::string, torch::Tensor> tensor_dict;
for (int64_t i = 0; i < num_tensors; ++i) {
bool has_value =
read_from_archive(archive, "tensor_" + std::to_string(i) + "_has_value")
.toBool();
if (has_value) {
auto shape =
read_from_archive(archive, "tensor_" + std::to_string(i) + "_shape")
.toIntVector();
auto dtype =
read_from_archive(archive, "tensor_" + std::to_string(i) + "_dtype")
.toScalarType();
metas.push_back({true, shape, dtype});
} else {
metas.push_back({false, {}, torch::ScalarType::Undefined});
auto key =
read_from_archive(archive, std::string("key_") + std::to_string(i))
.toStringRef();
auto tensor = this->ReadTorchTensor();
tensor_dict.insert(key, tensor.value());
}
return tensor_dict;
}
void SharedMemoryHelper::WriteTorchArchiveInternal(
torch::serialize::OutputArchive& archive) {
std::stringstream serialized;
archive.save_to(serialized);
auto serialized_str = serialized.str();
auto metadata_ptr = this->GetCurrentMetadataPtr();
static_cast<int64_t*>(metadata_ptr)[0] = serialized_str.size();
memcpy(
static_cast<char*>(metadata_ptr) + sizeof(int64_t), serialized_str.data(),
serialized_str.size());
int64_t rounded_size = GetRoundedSize(serialized_str.size());
this->MoveMetadataPtr(sizeof(int64_t) + rounded_size);
}
void SharedMemoryHelper::WriteTorchTensorInternal(
torch::optional<torch::Tensor> tensor) {
if (tensor.has_value()) {
size_t memory_size = tensor.value().numel() * tensor.value().element_size();
auto data_ptr = this->GetCurrentDataPtr();
auto contiguous_tensor = tensor.value().contiguous();
memcpy(data_ptr, contiguous_tensor.data_ptr(), memory_size);
this->MoveDataPtr(GetRoundedSize(memory_size));
}
}
void SharedMemoryHelper::Flush() {
// The first archive records the size of the tensor data.
torch::serialize::OutputArchive archive;
size_t data_size = 0;
for (auto tensor : tensors_to_write_) {
if (tensor.has_value()) {
auto tensor_size = tensor.value().numel() * tensor.value().element_size();
data_size += GetRoundedSize(tensor_size);
}
}
return std::tuple_cat(
std::forward_as_tuple(std::move(meta_shm)),
LoadTensorsDataFromSharedMemory(name + "_data", metas));
archive.write("data_size", static_cast<int64_t>(data_size));
metadata_shared_memory_ =
std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));
metadata_shared_memory_->Create(max_metadata_size_);
metadata_offset_ = 0;
this->WriteTorchArchiveInternal(archive);
for (auto& archive : metadata_to_write_) {
this->WriteTorchArchiveInternal(archive);
}
data_shared_memory_ =
std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));
data_shared_memory_->Create(data_size);
data_offset_ = 0;
for (auto tensor : tensors_to_write_) {
this->WriteTorchTensorInternal(tensor);
}
metadata_to_write_.clear();
tensors_to_write_.clear();
}
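To summarize the result of Flush, the two segments end up with the following layout (reconstructed from the code above; field widths are illustrative):
// <name>_meta: [int64 size | data_size archive | pad][int64 size | archive #1 | pad]...
//              every archive payload is padded up to a multiple of 8 bytes
// <name>_data: [tensor #1 bytes | pad][tensor #2 bytes | pad]...
//              every tensor's byte size is likewise rounded up to a multiple of 8 bytes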
std::pair<SharedMemoryPtr, SharedMemoryPtr>
SharedMemoryHelper::ReleaseSharedMemory() {
return std::make_pair(
std::move(metadata_shared_memory_), std::move(data_shared_memory_));
}
} // namespace sampling
......
......@@ -20,52 +20,122 @@ namespace graphbolt {
namespace sampling {
/**
* @brief SharedMemoryTensors includes: (1) two shared memory objects holding
* tensor meta information and data respectively; (2) a vector of optional
* tensors on shared memory.
*/
using SharedMemoryTensors = std::tuple<
SharedMemoryPtr, SharedMemoryPtr,
std::vector<torch::optional<torch::Tensor>>>;
/**
* @brief Copy torch tensors to shared memory.
* @brief SharedMemoryHelper is a helper class to write/read data structures
* to/from shared memory.
*
* To simplify this interface, a regular tensor is also wrapped as an optional
* one.
* To write a data structure to shared memory, we first serialize it to a
* binary buffer and then write that buffer to the shared memory. However, the
* size of the binary buffer is not known in advance. To solve this, we use two
* shared memory objects: one for storing the metadata and the other for
* storing the binary buffer. The metadata holds the meta information of the
* data structures, such as size and shape. The size of the metadata segment is
* chosen by the user via `max_metadata_size`; the size of the binary buffer is
* determined by the data structures themselves.
*
* The function has two steps:
* 1. Copy meta info to shared memory `shared_memory_name + "_meta"`. This is to
* make sure that other loading processes can get the meta info of tensors.
* 2. Copy tensors to shared memory `shared_memory_name + "_data"`, which can be
* loaded by other processes with meta info.
* To avoid repeated shared memory allocation, this helper class writes lazily:
* data structures are written to the shared memory only when `Flush` is
* called. They are written in the order in which `WriteTorchArchive`,
* `WriteTorchTensor`, and `WriteTorchTensorDict` are called, and are read
* back in the same order.
*
* The order of tensors loaded from `LoadTensorsFromSharedMemory` will be
* exactly the same as the tensors copied from `CopyTensorsToSharedMemory`.
* The usage of this class as a writer is as follows:
* @code{.cpp}
* SharedMemoryHelper shm_helper("shm_name", 1024);
* shm_helper.WriteTorchArchive(archive);
* shm_helper.WriteTorchTensor(tensor);
* shm_helper.WriteTorchTensorDict(tensor_dict);
* shm_helper.Flush();
* // After `Flush`, the data structures are written to the shared memory.
* // Then the helper class can be used as a reader.
* shm_helper.InitializeRead();
* auto archive = shm_helper.ReadTorchArchive();
* auto tensor = shm_helper.ReadTorchTensor();
* auto tensor_dict = shm_helper.ReadTorchTensorDict();
* @endcode
*
* @param name The name of shared memory.
* @param tensors The tensors to copy.
* @param max_meta_memory_size The maximum size of meta memory.
*
* @return A tuple of tensor meta shared memory, tensor data shared memory, and
* shared optional tensors.
*/
SharedMemoryTensors CopyTensorsToSharedMemory(
const std::string& name,
const std::vector<torch::optional<torch::Tensor>>& tensors,
int64_t max_meta_memory_size);
/**
* @brief Load torch tensors from shared memory.
* The usage of this class as a reader is as follows:
* @code{.cpp}
* SharedMemoryHelper shm_helper("shm_name", 1024);
* shm_helper.InitializeRead();
* auto archive = shm_helper.ReadTorchArchive();
* auto tensor = shm_helper.ReadTorchTensor();
* auto tensor_dict = shm_helper.ReadTorchTensorDict();
* @endcode
*
* @param name The name of shared memory.
* @param max_meta_memory_size The maximum size of meta memory.
*
* @return A tuple of tensor meta shared memory, tensor data shared memory,
* and shared tensors.
*/
SharedMemoryTensors LoadTensorsFromSharedMemory(
const std::string& name, int64_t max_meta_memory_size);
class SharedMemoryHelper {
public:
/**
* @brief Constructor of the shared memory helper.
* @param name The name of the shared memory.
* @param max_metadata_size The maximum size of metadata.
*/
SharedMemoryHelper(const std::string& name, int64_t max_metadata_size);
/** @brief Initialize this helper class before reading. */
void InitializeRead();
void WriteTorchArchive(torch::serialize::OutputArchive&& archive);
torch::serialize::InputArchive ReadTorchArchive();
void WriteTorchTensor(torch::optional<torch::Tensor> tensor);
torch::optional<torch::Tensor> ReadTorchTensor();
void WriteTorchTensorDict(
torch::optional<torch::Dict<std::string, torch::Tensor>> tensor_dict);
torch::optional<torch::Dict<std::string, torch::Tensor>>
ReadTorchTensorDict();
/** @brief Flush the data structures to the shared memory. */
void Flush();
/** @brief Release the two shared memory objects, transferring their ownership to the caller. */
std::pair<SharedMemoryPtr, SharedMemoryPtr> ReleaseSharedMemory();
private:
/**
* @brief Write the metadata to the shared memory. This function is
* called by `Flush`.
*/
void WriteTorchArchiveInternal(torch::serialize::OutputArchive& archive);
/**
* @brief Write the tensor data to the shared memory. This function is
* called by `Flush`.
*/
void WriteTorchTensorInternal(torch::optional<torch::Tensor> tensor);
inline void* GetCurrentMetadataPtr() const {
return static_cast<char*>(metadata_shared_memory_->GetMemory()) +
metadata_offset_;
}
inline void* GetCurrentDataPtr() const {
return static_cast<char*>(data_shared_memory_->GetMemory()) + data_offset_;
}
inline void MoveMetadataPtr(int64_t offset) {
TORCH_CHECK(
metadata_offset_ + offset <= max_metadata_size_,
"The size of metadata exceeds the maximum size of shared memory.");
metadata_offset_ += offset;
}
inline void MoveDataPtr(int64_t offset) { data_offset_ += offset; }
std::string name_;
bool is_creator_;
int64_t max_metadata_size_;
// The shared memory objects for storing metadata and tensor data.
SharedMemoryPtr metadata_shared_memory_, data_shared_memory_;
// The read/write offsets of the metadata and tensor data.
int64_t metadata_offset_, data_offset_;
// The data structures to write to the shared memory. They are written to the
// shared memory only when `Flush` is called.
std::vector<torch::serialize::OutputArchive> metadata_to_write_;
std::vector<torch::optional<torch::Tensor>> tensors_to_write_;
};
} // namespace sampling
} // namespace graphbolt
......
......@@ -807,9 +807,17 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
@pytest.mark.parametrize(
"num_nodes, num_edges", [(1, 1), (100, 1), (10, 50), (1000, 50000)]
)
def test_homo_graph_on_shared_memory(num_nodes, num_edges):
@pytest.mark.parametrize("test_edge_attrs", [True, False])
def test_homo_graph_on_shared_memory(num_nodes, num_edges, test_edge_attrs):
csc_indptr, indices = gbt.random_homo_graph(num_nodes, num_edges)
graph = gb.from_csc(csc_indptr, indices)
if test_edge_attrs:
edge_attributes = {
"A1": torch.randn(num_edges),
"A2": torch.randn(num_edges),
}
else:
edge_attributes = None
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes)
shm_name = "test_homo_g"
graph1 = graph.copy_to_shared_memory(shm_name)
......@@ -834,6 +842,15 @@ def test_homo_graph_on_shared_memory(num_nodes, num_edges):
)
check_tensors_on_the_same_shared_memory(graph1.indices, graph2.indices)
if test_edge_attrs:
for name, edge_attr in edge_attributes.items():
assert name in graph1.edge_attributes
assert name in graph2.edge_attributes
assert torch.equal(graph1.edge_attributes[name], edge_attr)
check_tensors_on_the_same_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name]
)
assert graph1.metadata is None and graph2.metadata is None
assert graph1.node_type_offset is None and graph2.node_type_offset is None
assert graph1.type_per_edge is None and graph2.type_per_edge is None
......@@ -847,8 +864,9 @@ def test_homo_graph_on_shared_memory(num_nodes, num_edges):
"num_nodes, num_edges", [(1, 1), (100, 1), (10, 50), (1000, 50000)]
)
@pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
@pytest.mark.parametrize("test_edge_attrs", [True, False])
def test_hetero_graph_on_shared_memory(
num_nodes, num_edges, num_ntypes, num_etypes
num_nodes, num_edges, num_ntypes, num_etypes, test_edge_attrs
):
(
csc_indptr,
......@@ -857,8 +875,21 @@ def test_hetero_graph_on_shared_memory(
type_per_edge,
metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
if test_edge_attrs:
edge_attributes = {
"A1": torch.randn(num_edges),
"A2": torch.randn(num_edges),
}
else:
edge_attributes = None
graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
)
shm_name = "test_hetero_g"
......@@ -894,6 +925,15 @@ def test_hetero_graph_on_shared_memory(
graph1.type_per_edge, graph2.type_per_edge
)
if test_edge_attrs:
for name, edge_attr in edge_attributes.items():
assert name in graph1.edge_attributes
assert name in graph2.edge_attributes
assert torch.equal(graph1.edge_attributes[name], edge_attr)
check_tensors_on_the_same_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name]
)
assert metadata.node_type_to_id == graph1.metadata.node_type_to_id
assert metadata.edge_type_to_id == graph1.metadata.edge_type_to_id
assert metadata.node_type_to_id == graph2.metadata.node_type_to_id
......