/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/gpu_cache.cu
 * @brief GPUCache implementation on CUDA.
 */
#include <numeric>

#include "./common.h"
#include "./gpu_cache.h"

namespace graphbolt {
namespace cuda {

GpuCache::GpuCache(const std::vector<int64_t> &shape, torch::ScalarType dtype) {
  TORCH_CHECK(shape.size() >= 2, "Shape must at least have 2 dimensions.");
  const auto num_items = shape[0];
  const int64_t num_feats =
      std::accumulate(shape.begin() + 1, shape.end(), 1ll, std::multiplies<>());
  const int element_size =
      torch::empty(1, torch::TensorOptions().dtype(dtype)).element_size();
  num_bytes_ = num_feats * element_size;
  // The underlying cache stores rows as float words; round the row size up to
  // a whole number of floats.
  num_float_feats_ = (num_bytes_ + sizeof(float) - 1) / sizeof(float);
  cache_ = std::make_unique<gpu_cache_t>(
      (num_items + bucket_size - 1) / bucket_size, num_float_feats_);
  shape_ = shape;
  shape_[0] = -1;
  dtype_ = dtype;
  device_id_ = cuda::GetCurrentStream().device_index();
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
    torch::Tensor keys) {
  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
  TORCH_CHECK(
      keys.device().index() == device_id_,
      "Keys should be on the correct CUDA device.");
  TORCH_CHECK(keys.sizes().size() == 1, "Keys should be a 1D tensor.");
  keys = keys.to(torch::kLong);
  auto values = torch::empty(
      {keys.size(0), num_float_feats_}, keys.options().dtype(torch::kFloat));
  auto missing_index =
      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
  auto missing_keys =
      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
  auto allocator = cuda::GetAllocator();
  auto missing_len_device = allocator.AllocateStorage<size_t>(1);
  cache_->Query(
      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
      values.data_ptr<float>(),
      reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
      reinterpret_cast<key_t *>(missing_keys.data_ptr()),
      missing_len_device.get(), cuda::GetCurrentStream());
  // Reinterpret the float-padded rows back into the original dtype and shape.
  values = values.view(torch::kByte)
               .slice(1, 0, num_bytes_)
               .view(dtype_)
               .view(shape_);
  cuda::CopyScalar<size_t> missing_len(missing_len_device.get());
  missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
  missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
  return std::make_tuple(values, missing_index, missing_keys);
}

void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
  TORCH_CHECK(
      keys.device().index() == device_id_,
      "Keys should be on the correct CUDA device.");
  TORCH_CHECK(values.device().is_cuda(), "Values should be on a CUDA device.");
  TORCH_CHECK(
      values.device().index() == device_id_,
      "Values should be on the correct CUDA device.");
  TORCH_CHECK(
      keys.size(0) == values.size(0),
      "The first dimensions of keys and values must match.");
  TORCH_CHECK(
      std::equal(shape_.begin() + 1, shape_.end(), values.sizes().begin() + 1),
      "Values should have the correct dimensions.");
  TORCH_CHECK(
      values.scalar_type() == dtype_, "Values should have the correct dtype.");
  if (keys.numel() == 0) return;
  keys = keys.to(torch::kLong);
  torch::Tensor float_values;
  if (num_bytes_ % sizeof(float) != 0) {
    // Row size is not a multiple of sizeof(float); copy into a float-padded
    // buffer so the cache can store whole float words.
    float_values = torch::empty(
        {values.size(0), num_float_feats_},
        values.options().dtype(torch::kFloat));
    float_values.view(torch::kByte)
        .slice(1, 0, num_bytes_)
        .copy_(values.view(torch::kByte).view({values.size(0), -1}));
  } else {
    // Row size is already float-aligned; reinterpret the bytes in place.
    float_values = values.view(torch::kByte)
                       .view({values.size(0), -1})
                       .view(torch::kFloat)
                       .contiguous();
  }
  cache_->Replace(
      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
      float_values.data_ptr<float>(), cuda::GetCurrentStream());
}
c10::intrusive_ptr<GpuCache> GpuCache::Create(
    const std::vector<int64_t> &shape, torch::ScalarType dtype) {
  return c10::make_intrusive<GpuCache>(shape, dtype);
}

}  // namespace cuda
}  // namespace graphbolt
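
// A minimal usage sketch (illustrative only, not part of the build): it shows
// how a caller might combine Query and Replace so that cache misses are filled
// from a fallback feature store and then inserted into the cache. The
// `full_feature_store` tensor and the surrounding setup are hypothetical; only
// GpuCache::Create, Query and Replace come from this file.
//
//   auto cache = graphbolt::cuda::GpuCache::Create(
//       /*shape=*/{num_cached_rows, feat_dim}, torch::kFloat32);
//   // keys: 1D CUDA tensor of row ids to read.
//   auto [values, missing_index, missing_keys] = cache->Query(keys);
//   // Fetch the rows the cache did not have from the assumed backing store.
//   auto missing_values = full_feature_store.index_select(0, missing_keys.cpu())
//                             .to(keys.device());
//   // Scatter the fetched rows into the output and warm the cache with them.
//   values.index_copy_(0, missing_index, missing_values);
//   cache->Replace(missing_keys, missing_values);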