Unverified commit d5b03bcb, authored by Muhammed Fatih BALIN, committed by GitHub
Browse files

[GraphBolt][CUDA] GPUCache performance fix. (#7073)

parent 85683869
...@@ -43,20 +43,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query( ...@@ -43,20 +43,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
torch::empty(keys.size(0), keys.options().dtype(torch::kLong)); torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
auto missing_keys = auto missing_keys =
torch::empty(keys.size(0), keys.options().dtype(torch::kLong)); torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
cuda::CopyScalar<size_t> missing_len; auto allocator = cuda::GetAllocator();
auto stream = cuda::GetCurrentStream(); auto missing_len_device = allocator.AllocateStorage<size_t>(1);
cache_->Query( cache_->Query(
reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0), reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
values.data_ptr<float>(), values.data_ptr<float>(),
reinterpret_cast<uint64_t *>(missing_index.data_ptr()), reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(), reinterpret_cast<key_t *>(missing_keys.data_ptr()),
stream); missing_len_device.get(), cuda::GetCurrentStream());
values = values.view(torch::kByte) values = values.view(torch::kByte)
.slice(1, 0, num_bytes_) .slice(1, 0, num_bytes_)
.view(dtype_) .view(dtype_)
.view(shape_); .view(shape_);
// To safely read missing_len, we synchronize cuda::CopyScalar<size_t> missing_len(missing_len_device.get());
stream.synchronize();
missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len)); missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len)); missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
return std::make_tuple(values, missing_index, missing_keys); return std::make_tuple(values, missing_index, missing_keys);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment