Unverified commit d5b03bcb, authored by Muhammed Fatih BALIN and committed by GitHub
Browse files

[GraphBolt][CUDA] GPUCache performance fix. (#7073)

parent 85683869
......@@ -43,20 +43,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
auto missing_keys =
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
cuda::CopyScalar<size_t> missing_len;
auto stream = cuda::GetCurrentStream();
auto allocator = cuda::GetAllocator();
auto missing_len_device = allocator.AllocateStorage<size_t>(1);
cache_->Query(
reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
values.data_ptr<float>(),
reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
stream);
reinterpret_cast<key_t *>(missing_keys.data_ptr()),
missing_len_device.get(), cuda::GetCurrentStream());
values = values.view(torch::kByte)
.slice(1, 0, num_bytes_)
.view(dtype_)
.view(shape_);
// To safely read missing_len, we synchronize
stream.synchronize();
cuda::CopyScalar<size_t> missing_len(missing_len_device.get());
missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
return std::make_tuple(values, missing_index, missing_keys);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment