/** * Copyright (c) 2023 by Contributors * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek) * @file cuda/gpu_cache.h * @brief Header file of HugeCTR gpu_cache wrapper. */ #ifndef GRAPHBOLT_GPU_CACHE_H_ #define GRAPHBOLT_GPU_CACHE_H_ #include #include #include #include namespace graphbolt { namespace cuda { class GpuCache : public torch::CustomClassHolder { using key_t = long long; constexpr static int set_associativity = 2; constexpr static int WARP_SIZE = 64; constexpr static int bucket_size = WARP_SIZE * set_associativity; using gpu_cache_t = ::gpu_cache::gpu_cache< key_t, uint64_t, std::numeric_limits::max(), set_associativity, WARP_SIZE>; public: /** * @brief Constructor for the GpuCache struct. * * @param shape The shape of the GPU cache. * @param dtype The datatype of items to be stored. */ GpuCache(const std::vector& shape, torch::ScalarType dtype); GpuCache() = default; std::tuple Query( torch::Tensor keys); void Replace(torch::Tensor keys, torch::Tensor values); static c10::intrusive_ptr Create( const std::vector& shape, torch::ScalarType dtype); private: std::vector shape_; torch::ScalarType dtype_; std::unique_ptr cache_; int64_t num_bytes_; int64_t num_float_feats_; torch::DeviceIndex device_id_; }; // The cu file in HugeCTR gpu cache uses unsigned int and long long. // Changing to int64_t results in a mismatch of template arguments. static_assert( sizeof(long long) == sizeof(int64_t), "long long and int64_t needs to have the same size."); // NOLINT } // namespace cuda } // namespace graphbolt #endif // GRAPHBOLT_GPU_CACHE_H_