"vscode:/vscode.git/clone" did not exist on "2c04ecb55a121587f6fc6beb6ce1bb593f7bccc1"
Unverified Commit c8b18b79 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[RPC] Support null NDArray in RPC communication (#1653)

parent 8192b10b
......@@ -69,7 +69,10 @@ void StreamWithBuffer::PushNDArray(const NDArray& tensor) {
// If the stream is for remote communication or the data is not stored in
// shared memory, serialize the data content as a buffer.
this->Write<bool>(false);
buffer_list_.emplace_back(tensor, tensor->data, data_byte_size);
// If this is a null ndarray, we will not push it into the underlying buffer_list
if (data_byte_size != 0) {
buffer_list_.emplace_back(tensor, tensor->data, data_byte_size);
}
} else {
CHECK(mem) << "Tried to send non-shared-memroy tensor to local "
"StreamWithBuffer";
......@@ -111,9 +114,15 @@ NDArray StreamWithBuffer::PopNDArray() {
} else {
CHECK(send_to_remote_) << "Invalid attempt to deserialize from raw data "
"pointer with send_to_remote=false";
auto ret = CreateNDArrayFromRawData(shape, dtype, cpu_ctx,
buffer_list_.front().data);
buffer_list_.pop_front();
NDArray ret;
if (ndim == 0 || shape[0] == 0) {
// Mean this is a null ndarray
ret = CreateNDArrayFromRawData(shape, dtype, cpu_ctx, nullptr);
} else {
ret = CreateNDArrayFromRawData(shape, dtype, cpu_ctx,
buffer_list_.front().data);
buffer_list_.pop_front();
}
return ret;
}
#else
......
......@@ -69,6 +69,34 @@ TEST(ZeroCopySerialize, NDArray) {
zc_read_strm.Read(&loadtensor2);
}
TEST(ZeroCopySerialize, ZeroShapeNDArray) {
auto tensor1 = VecToIdArray<int64_t>({6, 6, 5, 7});
auto tensor2 = VecToIdArray<int64_t>({});
auto tensor3 = VecToIdArray<int64_t>({6, 6, 2, 7});
std::vector<NDArray> ndvec;
ndvec.push_back(tensor1);
ndvec.push_back(tensor2);
ndvec.push_back(tensor3);
std::string zerocopy_blob;
StreamWithBuffer zc_write_strm(&zerocopy_blob, true);
zc_write_strm.Write(ndvec);
std::vector<void *> new_ptr_list;
// Use memcpy to mimic remote machine reconstruction
for (auto ptr : zc_write_strm.buffer_list()) {
auto new_ptr = malloc(ptr.size);
memcpy(new_ptr, ptr.data, ptr.size);
new_ptr_list.emplace_back(new_ptr);
}
std::vector<NDArray> ndvec_read;
StreamWithBuffer zc_read_strm(&zerocopy_blob, new_ptr_list);
zc_read_strm.Read(&ndvec_read);
EXPECT_EQ(ndvec_read[1]->ndim, 1);
EXPECT_EQ(ndvec_read[1]->shape[0], 0);
}
TEST(ZeroCopySerialize, SharedMem) {
auto tensor1 = VecToIdArray<int64_t>({1, 2, 5, 3});
DLDataType dtype = {kDLInt, 64, 1};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment