Unverified Commit 69a532c1 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[Feature] Gpu cache for node and edge data (#4341)


Co-authored-by: default avatarxiny <xiny@nvidia.com>
parent 7ec78bb6
...@@ -39,6 +39,7 @@ include_patterns = [ ...@@ -39,6 +39,7 @@ include_patterns = [
'**/*.cu', '**/*.cu',
] ]
exclude_patterns = [ exclude_patterns = [
'third_party/**',
] ]
init_command = [ init_command = [
'python3', 'python3',
......
...@@ -227,6 +227,21 @@ if((NOT MSVC) AND (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")) ...@@ -227,6 +227,21 @@ if((NOT MSVC) AND (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--exclude-libs,ALL") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--exclude-libs,ALL")
endif() endif()
# Compile gpu_cache
if(USE_CUDA)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_GPU_CACHE")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_GPU_CACHE")
# Manually build gpu_cache because CMake always builds it as shared
file(GLOB gpu_cache_src
third_party/HugeCTR/gpu_cache/src/nv_gpu_cache.cu
)
cuda_add_library(gpu_cache STATIC ${gpu_cache_src})
target_include_directories(gpu_cache PRIVATE "third_party/HugeCTR/gpu_cache/include")
target_include_directories(dgl PRIVATE "third_party/HugeCTR/gpu_cache/include")
list(APPEND DGL_LINKER_LIBS gpu_cache)
message(STATUS "Build with HugeCTR GPU embedding cache.")
endif(USE_CUDA)
# support PARALLEL_ALGORITHMS # support PARALLEL_ALGORITHMS
if (LIBCXX_ENABLE_PARALLEL_ALGORITHMS) if (LIBCXX_ENABLE_PARALLEL_ALGORITHMS)
add_definitions(-DPARALLEL_ALGORITHMS) add_definitions(-DPARALLEL_ALGORITHMS)
......
...@@ -202,59 +202,6 @@ function(dgl_select_nvcc_arch_flags out_variable) ...@@ -202,59 +202,6 @@ function(dgl_select_nvcc_arch_flags out_variable)
set(${out_variable}_readable ${__nvcc_archs_readable} PARENT_SCOPE) set(${out_variable}_readable ${__nvcc_archs_readable} PARENT_SCOPE)
endfunction() endfunction()
################################################################################################
# Short command for cuda compilation
# Usage:
# dgl_cuda_compile(<objlist_variable> <cuda_files>)
macro(dgl_cuda_compile objlist_variable)
foreach(var CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_DEBUG)
set(${var}_backup_in_cuda_compile_ "${${var}}")
# we remove /EHa as it generates warnings under windows
string(REPLACE "/EHa" "" ${var} "${${var}}")
endforeach()
if(UNIX OR APPLE)
list(APPEND CUDA_NVCC_FLAGS -Xcompiler -fPIC --std=c++14)
endif()
if(APPLE)
list(APPEND CUDA_NVCC_FLAGS -Xcompiler -Wno-unused-function)
endif()
set(CUDA_NVCC_FLAGS_DEBUG "${CUDA_NVCC_FLAGS_DEBUG} -G")
if(MSVC)
# disable noisy warnings:
# 4819: The file contains a character that cannot be represented in the current code page (number).
list(APPEND CUDA_NVCC_FLAGS -Xcompiler "/wd4819")
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
if(${flag_var} MATCHES "/MD")
string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
endif(${flag_var} MATCHES "/MD")
endforeach(flag_var)
endif()
# If the build system is a container, make sure the nvcc intermediate files
# go into the build output area rather than in /tmp, which may run out of space
if(IS_CONTAINER_BUILD)
set(CUDA_NVCC_INTERMEDIATE_DIR "${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS "Container build enabled, so nvcc intermediate files in: ${CUDA_NVCC_INTERMEDIATE_DIR}")
list(APPEND CUDA_NVCC_FLAGS "--keep --keep-dir ${CUDA_NVCC_INTERMEDIATE_DIR}")
endif()
cuda_compile(cuda_objcs ${ARGN})
foreach(var CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_DEBUG)
set(${var} "${${var}_backup_in_cuda_compile_}")
unset(${var}_backup_in_cuda_compile_)
endforeach()
set(${objlist_variable} ${cuda_objcs})
endmacro()
################################################################################################ ################################################################################################
# Config cuda compilation. # Config cuda compilation.
# Usage: # Usage:
...@@ -289,7 +236,7 @@ macro(dgl_config_cuda out_variable) ...@@ -289,7 +236,7 @@ macro(dgl_config_cuda out_variable)
set(CUDA_PROPAGATE_HOST_FLAGS OFF) set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# 0. Add host flags # 0. Add host flags
message(STATUS "${CMAKE_CXX_FLAGS}") message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
string(REGEX REPLACE "[ \t\n\r]" "," CXX_HOST_FLAGS "${CMAKE_CXX_FLAGS}") string(REGEX REPLACE "[ \t\n\r]" "," CXX_HOST_FLAGS "${CMAKE_CXX_FLAGS}")
if(MSVC AND NOT USE_MSVC_MT) if(MSVC AND NOT USE_MSVC_MT)
string(CONCAT CXX_HOST_FLAGS ${CXX_HOST_FLAGS} ",/MD") string(CONCAT CXX_HOST_FLAGS ${CXX_HOST_FLAGS} ",/MD")
...@@ -303,14 +250,7 @@ macro(dgl_config_cuda out_variable) ...@@ -303,14 +250,7 @@ macro(dgl_config_cuda out_variable)
# 2. flags in third_party/moderngpu # 2. flags in third_party/moderngpu
list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda;-Wno-deprecated-declarations") list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda;-Wno-deprecated-declarations")
message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")
# 3. CUDA 11 requires c++14 by default
include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14)
string(REPLACE "-std=c++11" "" CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}")
list(APPEND CUDA_NVCC_FLAGS "-std=c++14")
message(STATUS "CUDA flags: ${CUDA_NVCC_FLAGS}")
list(APPEND DGL_LINKER_LIBS list(APPEND DGL_LINKER_LIBS
${CUDA_CUDART_LIBRARY} ${CUDA_CUDART_LIBRARY}
......
""" CUDA wrappers """ """ CUDA wrappers """
from .. import backend as F from .. import backend as F
from .gpu_cache import GPUCache
if F.get_preferred_backend() == "pytorch": if F.get_preferred_backend() == "pytorch":
from . import nccl from . import nccl
"""API wrapping HugeCTR gpu_cache."""
# Copyright (c) 2022, NVIDIA Corporation
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @file gpu_cache.py
# @brief API for managing a GPU Cache
from .. import backend as F
from .._ffi.function import _init_api
class GPUCache(object):
"""High-level wrapper for GPU embedding cache"""
def __init__(self, num_items, num_feats, idtype=F.int64):
assert idtype in [F.int32, F.int64]
self._cache = _CAPI_DGLGpuCacheCreate(
num_items, num_feats, 32 if idtype == F.int32 else 64
)
self.idtype = idtype
self.total_miss = 0
self.total_queries = 0
def query(self, keys):
"""Queries the GPU cache.
Parameters
----------
keys : Tensor
The keys to query the GPU cache with.
Returns
-------
tuple(Tensor, Tensor, Tensor)
A tuple containing (values, missing_indices, missing_keys) where
values[missing_indices] corresponds to cache misses that should be
filled by quering another source with missing_keys.
"""
self.total_queries += keys.shape[0]
keys = F.astype(keys, self.idtype)
values, missing_index, missing_keys = _CAPI_DGLGpuCacheQuery(
self._cache, F.to_dgl_nd(keys)
)
self.total_miss += missing_keys.shape[0]
return (
F.from_dgl_nd(values),
F.from_dgl_nd(missing_index),
F.from_dgl_nd(missing_keys),
)
def replace(self, keys, values):
"""Inserts key-value pairs into the GPU cache using the Least-Recently
Used (LRU) algorithm to remove old key-value pairs if it is full.
Parameters
----------
keys: Tensor
The keys to insert to the GPU cache.
values: Tensor
The values to insert to the GPU cache.
"""
keys = F.astype(keys, self.idtype)
values = F.astype(values, F.float32)
_CAPI_DGLGpuCacheReplace(
self._cache, F.to_dgl_nd(keys), F.to_dgl_nd(values)
)
@property
def miss_rate(self):
"""Returns the cache miss rate since creation."""
return self.total_miss / self.total_queries
_init_api("dgl.cuda", __name__)
/*!
* Copyright (c) 2022 by Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file gpu_cache.cu
* \brief Implementation of wrapper HugeCTR gpu_cache routines.
*/
#ifndef DGL_RUNTIME_CUDA_GPU_CACHE_H_
#define DGL_RUNTIME_CUDA_GPU_CACHE_H_
#include <cuda_runtime.h>
#include <dgl/array.h>
#include <dgl/aten/array_ops.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/object.h>
#include <dgl/runtime/registry.h>
#include <nv_gpu_cache.hpp>
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
namespace runtime {
namespace cuda {
template <typename key_t>
class GpuCache : public runtime::Object {
constexpr static int set_associativity = 2;
constexpr static int WARP_SIZE = 32;
constexpr static int bucket_size = WARP_SIZE * set_associativity;
using gpu_cache_t = gpu_cache::gpu_cache<
key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,
WARP_SIZE>;
public:
static constexpr const char *_type_key =
sizeof(key_t) == 4 ? "cuda.GpuCache32" : "cuda.GpuCache64";
DGL_DECLARE_OBJECT_TYPE_INFO(GpuCache, Object);
GpuCache(size_t num_items, size_t num_feats)
: num_feats(num_feats),
cache(std::make_unique<gpu_cache_t>(
(num_items + bucket_size - 1) / bucket_size, num_feats)) {
CUDA_CALL(cudaGetDevice(&cuda_device));
}
std::tuple<NDArray, IdArray, IdArray> Query(IdArray keys) {
const auto &ctx = keys->ctx;
cudaStream_t stream = dgl::runtime::getCurrentCUDAStream();
auto device = dgl::runtime::DeviceAPI::Get(ctx);
CHECK_EQ(ctx.device_type, kDGLCUDA)
<< "The keys should be on a CUDA device";
CHECK_EQ(ctx.device_id, cuda_device)
<< "The keys should be on the correct CUDA device";
CHECK_EQ(keys->ndim, 1)
<< "The tensor of requested indices must be of dimension one.";
NDArray values = NDArray::Empty(
{keys->shape[0], (int64_t)num_feats}, DGLDataType{kDGLFloat, 32, 1},
ctx);
IdArray missing_index = aten::NewIdArray(keys->shape[0], ctx, 64);
IdArray missing_keys =
aten::NewIdArray(keys->shape[0], ctx, sizeof(key_t) * 8);
size_t *missing_len =
static_cast<size_t *>(device->AllocWorkspace(ctx, sizeof(size_t)));
cache->Query(
static_cast<const key_t *>(keys->data), keys->shape[0],
static_cast<float *>(values->data),
static_cast<uint64_t *>(missing_index->data),
static_cast<key_t *>(missing_keys->data), missing_len, stream);
size_t missing_len_host;
device->CopyDataFromTo(
missing_len, 0, &missing_len_host, 0, sizeof(missing_len_host), ctx,
DGLContext{kDGLCPU, 0}, keys->dtype);
device->FreeWorkspace(ctx, missing_len);
missing_index = missing_index.CreateView(
{(int64_t)missing_len_host}, missing_index->dtype);
missing_keys =
missing_keys.CreateView({(int64_t)missing_len_host}, keys->dtype);
return std::make_tuple(values, missing_index, missing_keys);
}
void Replace(IdArray keys, NDArray values) {
cudaStream_t stream = dgl::runtime::getCurrentCUDAStream();
CHECK_EQ(keys->ctx.device_type, kDGLCUDA)
<< "The keys should be on a CUDA device";
CHECK_EQ(keys->ctx.device_id, cuda_device)
<< "The keys should be on the correct CUDA device";
CHECK_EQ(values->ctx.device_type, kDGLCUDA)
<< "The values should be on a CUDA device";
CHECK_EQ(values->ctx.device_id, cuda_device)
<< "The values should be on the correct CUDA device";
CHECK_EQ(keys->shape[0], values->shape[0])
<< "First dimensions of keys and values must match";
CHECK_EQ(values->shape[1], num_feats) << "Embedding dimension must match";
cache->Replace(
static_cast<const key_t *>(keys->data), keys->shape[0],
static_cast<const float *>(values->data), stream);
}
private:
size_t num_feats;
std::unique_ptr<gpu_cache_t> cache;
int cuda_device;
};
static_assert(sizeof(unsigned int) == 4);
DGL_DEFINE_OBJECT_REF(GpuCacheRef32, GpuCache<unsigned int>);
// The cu file in HugeCTR gpu cache uses unsigned int and long long.
// Changing to int64_t results in a mismatch of template arguments.
static_assert(sizeof(long long) == 8); // NOLINT
DGL_DEFINE_OBJECT_REF(GpuCacheRef64, GpuCache<long long>); // NOLINT
/* CAPI **********************************************************************/
using namespace dgl::runtime;
DGL_REGISTER_GLOBAL("cuda._CAPI_DGLGpuCacheCreate")
.set_body([](DGLArgs args, DGLRetValue *rv) {
const size_t num_items = args[0];
const size_t num_feats = args[1];
const int num_bits = args[2];
if (num_bits == 32)
*rv = GpuCacheRef32(
std::make_shared<GpuCache<unsigned int>>(num_items, num_feats));
else
*rv = GpuCacheRef64(std::make_shared<GpuCache<long long>>( // NOLINT
num_items, num_feats));
});
DGL_REGISTER_GLOBAL("cuda._CAPI_DGLGpuCacheQuery")
.set_body([](DGLArgs args, DGLRetValue *rv) {
IdArray keys = args[1];
List<ObjectRef> ret;
if (keys->dtype.bits == 32) {
GpuCacheRef32 cache = args[0];
auto result = cache->Query(keys);
ret.push_back(Value(MakeValue(std::get<0>(result))));
ret.push_back(Value(MakeValue(std::get<1>(result))));
ret.push_back(Value(MakeValue(std::get<2>(result))));
} else {
GpuCacheRef64 cache = args[0];
auto result = cache->Query(keys);
ret.push_back(Value(MakeValue(std::get<0>(result))));
ret.push_back(Value(MakeValue(std::get<1>(result))));
ret.push_back(Value(MakeValue(std::get<2>(result))));
}
*rv = ret;
});
DGL_REGISTER_GLOBAL("cuda._CAPI_DGLGpuCacheReplace")
.set_body([](DGLArgs args, DGLRetValue *rv) {
IdArray keys = args[1];
NDArray values = args[2];
if (keys->dtype.bits == 32) {
GpuCacheRef32 cache = args[0];
cache->Replace(keys, values);
} else {
GpuCacheRef64 cache = args[0];
cache->Replace(keys, values);
}
*rv = List<ObjectRef>{};
});
} // namespace cuda
} // namespace runtime
} // namespace dgl
#endif
#
# Copyright (c) 2022 by Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import backend as F
import dgl
from utils import parametrize_idtype
D = 5
def generate_graph(idtype, grad=False, add_data=True):
g = dgl.DGLGraph().to(F.ctx(), dtype=idtype)
g.add_nodes(10)
u, v = [], []
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
u.append(0)
v.append(i)
u.append(i)
v.append(9)
# add a back flow from 9 to 0
u.append(9)
v.append(0)
g.add_edges(u, v)
if add_data:
ncol = F.randn((10, D))
ecol = F.randn((17, D))
if grad:
ncol = F.attach_grad(ncol)
ecol = F.attach_grad(ecol)
g.ndata["h"] = ncol
g.edata["l"] = ecol
return g
@unittest.skipIf(not F.gpu_ctx(), reason="only necessary with GPU")
@parametrize_idtype
def test_gpu_cache(idtype):
g = generate_graph(idtype)
cache = dgl.cuda.GPUCache(5, D, idtype)
h = g.ndata["h"]
t = 5
keys = F.arange(0, t, dtype=idtype)
values, m_idx, m_keys = cache.query(keys)
m_values = h[F.tensor(m_keys, F.int64)]
values[F.tensor(m_idx, F.int64)] = m_values
cache.replace(m_keys, m_values)
keys = F.arange(3, 8, dtype=idtype)
values, m_idx, m_keys = cache.query(keys)
assert m_keys.shape[0] == 3 and m_idx.shape[0] == 3
m_values = h[F.tensor(m_keys, F.int64)]
values[F.tensor(m_idx, F.int64)] = m_values
assert (values != h[F.tensor(keys, F.int64)]).sum().item() == 0
cache.replace(m_keys, m_values)
if __name__ == "__main__":
test_gpu_cache(F.int64)
test_gpu_cache(F.int32)
# GPU Embedding Cache
This project implements an embedding cache on GPU memory that is designed for CTR inference and training workload.
The cache stores the hot pairs, (embedding id, embedding vectors), on GPU memory.
Storing the data on GPU memory reduces the traffic to the parameter server when performing embedding table lookup.
The cache is designed for CTR inference and training, it has following features and restrictions:
* All the backup memory-side operations are performed by the parameter server.
These operations include prefetching, latency hiding, and so on.
* This is a single-GPU design.
Each cache belongs to one GPU.
* The cache is thread-safe: multiple workers, CPU threads, can concurrently call the API of a single cache object with well-defined behavior.
* The cache implements a least recently used (LRU) replacement algorithm so that it caches the most recently queried embeddings.
* The embeddings stored inside the cache are unique: there are no duplicated embedding IDs in the cache.
## Project Structure
This project is a stand-alone module in HugeCTR project.
The root folder of this project is the `gpu_cache` folder under the HugeCTR root directory.
The `include` folder contains the headers for the cache library and the `src` folder contains the implementations and Makefile for the cache library.
The `test` folder contains a test that tests the correctness and performance of the GPU embedding cache.
The test also acts as sample code that shows how to use the cache.
The `nv_gpu_cache.hpp` file contains the definition of the main class, `gpu_cache`, that implements the GPU embedding cache.
The `nv_gpu_cache.cu` file contains the implementation.
As a module of HugeCTR, this project is built with and used by the HugeCTR project.
## Supported Data Types
* The cache supports 32 and 64-bit scalar integer types for the key (embedding ID) type.
For example, the data type declarations `unsigned int` and `long long` match these integer types.
* The cache supports a vector of floats for the value (embedding vector) type.
* You need to specify an empty key to indicate the empty bucket.
Do not use an empty key to represent any real key.
* Refer to the instantiation code at the end of the `nv_gpu_cache.cu` file for template parameters.
## Requirements
* NVIDIA GPU >= Volta (SM 70).
* CUDA environment >= 11.0.
* (Optional) libcu++ library >= 1.1.0.
The CUDA Toolkit 11.0 (Early Access) and above meets the required library version.
Using the libcu++ library provides better performance and more precisely-defined behavior.
You can enable libcu++ library by defining the `LIBCUDACXX_VERSION` macro when you compile.
Otherwise, the libcu++ library is not enabled.
* The default building option for HugeCTR is to disable the libcu++ library.
## Usage Overview
```c++
template<typename key_type,
typename ref_counter_type,
key_type empty_key,
int set_associativity,
int warp_size,
typename set_hasher = MurmurHash3_32<key_type>,
typename slab_hasher = Mod_Hash<key_type, size_t>>
class gpu_cache{
public:
//Ctor
gpu_cache(const size_t capacity_in_set, const size_t embedding_vec_size);
//Dtor
~gpu_cache();
// Query API, i.e. A single read from the cache
void Query(const key_type* d_keys,
const size_t len,
float* d_values,
uint64_t* d_missing_index,
key_type* d_missing_keys,
size_t* d_missing_len,
cudaStream_t stream,
const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO);
// Replace API, i.e. Follow the Query API to update the content of the cache to Most Recent
void Replace(const key_type* d_keys,
const size_t len,
const float* d_values,
cudaStream_t stream,
const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO);
// Update API, i.e. update the embeddings which exist in the cache
void Update(const key_type* d_keys,
const size_t len,
const float* d_values,
cudaStream_t stream,
const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO);
// Dump API, i.e. dump some slabsets' keys from the cache
void Dump(key_type* d_keys,
size_t* d_dump_counter,
const size_t start_set_index,
const size_t end_set_index,
cudaStream_t stream);
};
```
## API
`Constructor`
To create a new embedding cache, you need to provide the following:
* Template parameters:
+ key_type: the data type of embedding ID.
+ ref_counter_type: the data type of the internal counter. This data type should be 64bit unsigned integer(i.e. uint64_t), 32bit integer has the risk of overflow.
+ empty_key: the key value indicate for empty bucket(i.e. The empty key), user should never use empty key value to represent any real keys.
+ set_associativity: the hyper-parameter indicates how many slabs per cache set.(See `Performance hint` session below)
+ warp_size: the hyper-parameter indicates how many [key, value] pairs per slab. Acceptable value includes 1/2/4/8/16/32.(See `Performance hint` session below)
+ For other template parameters just use the default value.
* Parameters:
+ capacity_in_set: # of cache set in the embedding cache. So the total capacity of the embedding cache is `warp_size * set_associativity * capacity_in_set` [key, value] pairs.
+ embedding_vec_size: # of float per a embedding vector.
* The host thread will wait for the GPU kernels to complete before returning from the API, thus this API is synchronous with CPU thread. When returned, the initialization process of the cache is already done.
* The embedding cache will be created on the GPU where user call the constructor. Thus, user should set the host thread to the target CUDA device before creating the embedding cache. All resources(i.e. device-side buffers, CUDA streams) used later for this embedding cache should be allocated on the same CUDA device as the embedding cache.
* The constructor can be called only once, thus is not thread-safe.
`Destructor`
* The destructor clean up the embedding cache. This API should be called only once when user need to delete the embedding cache object, thus is not thread-safe.
`Query`
* Search `len` elements from device-side buffers `d_keys` in the cache and return the result in device-side buffer `d_values` if a key is hit in the cache.
* If a key is missing, the missing key and its index in the `d_keys` buffer will be returned in device-side buffers `d_missing_keys` and `d_missing_index`. The # of missing key will be return in device-side buffer `d_missing_len`. For simplicity, these buffers should have the same length as `d_keys` to avoid out-of-bound access.
* The GPU kernels will be launched in `stream` CUDA stream.
* The host thread will return from the API immediately after the kernels are launched, thus this API is Asynchronous with CPU thread.
* The keys to be queried in the `d_keys` buffer can have duplication. In this case, user will get duplicated returned values or missing information.
* This API is thread-safe and can be called concurrently with other APIs.
* For hyper-parameter `task_per_warp_tile`, see `Performance hint` session below.
`Replace`
* The API will replace `len` [key, value] pairs listed in `d_keys` and `d_values` into the embedding cache using the LRU replacement algorithm.
* The GPU kernels will be launched in `stream` CUDA stream.
* The host thread will return from the API immediately after the kernels are launched, thus this API is Asynchronous with CPU thread.
* The keys to be replaced in the `d_keys` buffer can have duplication and can be already stored inside the cache. In these cases, the cache will detect any possible duplication and maintain the uniqueness of all the [key ,value] pairs stored in the cache.
* This API is thread-safe and can be called concurrently with other APIs.
* This API will first try to insert the [key, value] pairs into the cache if there is any empty slot. If the cache is full, it will do the replacement.
* For hyper-parameter `task_per_warp_tile`, see `Performance hint` session below.
`Update`
* The API will search for `len` keys listed in `d_keys` buffer within the cache. If a key is found in the cache, this API will update the value associated with the key to the corresponding values provided in `d_values` buffer. If a key is not found in the cache, this API will do nothing to this key.
* The GPU kernels will be launched in `stream` CUDA stream.
* The host thread will return from the API immediately after the kernels are launched, thus this API is Asynchronous with CPU thread.
* If the keys to be updated in the `d_keys` buffer have duplication, all values associated with this key in the `d_values` buffer will be updated to the cache atomically. The final result depends on the order of updating the value.
* This API is thread-safe and can be called concurrently with other APIs.
* For hyper-parameter `task_per_warp_tile`, see `Performance hint` session below.
`Dump`
* The API will dump all the keys stored in [`start_set_index`, `end_set_index`) cache sets to `d_keys` buffer as a linear array(the key order is not guaranteed). The total # of keys dumped will be reported in `d_dump_counter` variable.
* The GPU kernels will be launched in `stream` CUDA stream.
* The host thread will return from the API immediately after the kernels are launched, thus this API is Asynchronous with CPU thread.
* This API is thread-safe and can be called concurrently with other APIs.
## More Information
* The detailed introduction of the GPU embedding cache data structure is presented at GTC China 2020: https://on-demand-gtc.gputechconf.com/gtcnew/sessionview.php?sessionName=cns20626-%e4%bd%bf%e7%94%a8+gpu+embedding+cache+%e5%8a%a0%e9%80%9f+ctr+%e6%8e%a8%e7%90%86%e8%bf%87%e7%a8%8b
* The `test` folder contains a example of using the GPU embedding cache.
* This project is used by `embedding_cache` class in `HugeCTR/include/inference/embedding_cache.hpp` which can be used as an example.
## Performance Hint
* The hyper-parameter `warp_size` should be keep as 32 by default. When the length for Query or Replace operations is small(~1-50k), user can choose smaller warp_size and increase the total # of cache set(while maintaining the same cache size) to increase the parallelism and improve the performance.
* The hyper-parameter `set_associativity` is critical to performance:
+ If set too small, may cause load imbalance between different cache sets(lower down the effective capacity of the cache, lower down the hit rate). To prevent this, the embedding cache uses a very random hash function to hash the keys to different cache set, thus will achieve load balance statistically. However, larger cache set will tends to have better load balance.
+ If set too large, the searching space for a single key will be very large. The performance of the embedding cache API will drop dramatically. Also, each set will be accessed exclusively, thus the more cache sets the higher parallelism can be achieved.
+ Recommend setting `set_associativity` to 2 or 4.
* The runtime hyper-parameter `task_per_warp_tile` is set to 1 as default parameter, thus users don't need to change their code to accommodate this interface change. This hyper-parameter determines how many keys are been queried/replaced/updated by a single warp tile. The acceptable value is between [1, `warp_size`]. For small to medium size operations to the cache, less task per warp tile can increase the total # of warp tiles running concurrently on the GPU chip, thus can bring significant performance improvement. For large size operations to the cache, the increased # of warp tile will not bring any performance improvement(even a little regression on the performance, ~5%). User can choose the value for this parameter based on the value of `len` parameter.
* The GPU is designed for optimizing throughput. Always try to batch up the inference task and try to have larger `query_size`.
* As the APIs of the embedding cache is asynchronous with host threads. Try to optimize the E2E inference pipeline by overlapping asynchronous tasks on GPU or between CPU and GPU. For example, after retrieving the missing values from the parameter server, user can combine the missing values with the hit values and do the rest of inference pipeline at the same time with the `Replace` API. Replacement is not necessarily happens together with Query all the time, user can do query multiple times then do a replacement if the hit rate is acceptable.
* Try different cache capacity and evaluate the hit rate. If the capacity of embedding cache can be larger than actual embedding footprint, the hit rate can be as high as 99%+.
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nv_util.h>
#define TASK_PER_WARP_TILE_MACRO 1
namespace gpu_cache {
///////////////////////////////////////////////////////////////////////////////////////////////////
// GPU Cache API
template <typename key_type>
class gpu_cache_api {
public:
virtual ~gpu_cache_api() noexcept(false) {}
// Query API, i.e. A single read from the cache
virtual void Query(const key_type* d_keys, const size_t len, float* d_values,
uint64_t* d_missing_index, key_type* d_missing_keys, size_t* d_missing_len,
cudaStream_t stream,
const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) = 0;
// Replace API, i.e. Follow the Query API to update the content of the cache to Most Recent
virtual void Replace(const key_type* d_keys, const size_t len, const float* d_values,
cudaStream_t stream,
const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) = 0;
// Update API, i.e. update the embeddings which exist in the cache
virtual void Update(const key_type* d_keys, const size_t len, const float* d_values,
cudaStream_t stream,
const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) = 0;
// Dump API, i.e. dump some slabsets' keys from the cache
virtual void Dump(key_type* d_keys, size_t* d_dump_counter, const size_t start_set_index,
const size_t end_set_index, cudaStream_t stream) = 0;
};
} // namespace gpu_cache
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// MurmurHash3_32 implementation from
// https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
//-----------------------------------------------------------------------------
// MurmurHash3 was written by Austin Appleby, and is placed in the public
// domain. The author hereby disclaims copyright to this source code.
// Note - The x86 and x64 versions do _not_ produce the same results, as the
// algorithms are optimized for their respective platforms. You can still
// compile and run any of them on any platform, but your performance with the
// non-native version will be less than optimal.
template <typename Key, uint32_t m_seed = 0>
struct MurmurHash3_32 {
using argument_type = Key;
using result_type = uint32_t;
/*__forceinline__
__host__ __device__
MurmurHash3_32() : m_seed( 0 ) {}*/
__forceinline__ __host__ __device__ static uint32_t rotl32(uint32_t x, int8_t r) {
return (x << r) | (x >> (32 - r));
}
__forceinline__ __host__ __device__ static uint32_t fmix32(uint32_t h) {
h ^= h >> 16;
h *= 0x85ebca6b;
h ^= h >> 13;
h *= 0xc2b2ae35;
h ^= h >> 16;
return h;
}
/* --------------------------------------------------------------------------*/
/**
* @Synopsis Combines two hash values into a new single hash value. Called
* repeatedly to create a hash value from several variables.
* Taken from the Boost hash_combine function
* https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html
*
* @Param lhs The first hash value to combine
* @Param rhs The second hash value to combine
*
* @Returns A hash value that intelligently combines the lhs and rhs hash values
*/
/* ----------------------------------------------------------------------------*/
__host__ __device__ static result_type hash_combine(result_type lhs, result_type rhs) {
result_type combined{lhs};
combined ^= rhs + 0x9e3779b9 + (combined << 6) + (combined >> 2);
return combined;
}
__forceinline__ __host__ __device__ static result_type hash(const Key& key) {
constexpr int len = sizeof(argument_type);
const uint8_t* const data = (const uint8_t*)&key;
constexpr int nblocks = len / 4;
uint32_t h1 = m_seed;
constexpr uint32_t c1 = 0xcc9e2d51;
constexpr uint32_t c2 = 0x1b873593;
//----------
// body
const uint32_t* const blocks = (const uint32_t*)(data + nblocks * 4);
for (int i = -nblocks; i; i++) {
uint32_t k1 = blocks[i]; // getblock32(blocks,i);
k1 *= c1;
k1 = rotl32(k1, 15);
k1 *= c2;
h1 ^= k1;
h1 = rotl32(h1, 13);
h1 = h1 * 5 + 0xe6546b64;
}
//----------
// tail
const uint8_t* tail = (const uint8_t*)(data + nblocks * 4);
uint32_t k1 = 0;
switch (len & 3) {
case 3:
k1 ^= tail[2] << 16;
case 2:
k1 ^= tail[1] << 8;
case 1:
k1 ^= tail[0];
k1 *= c1;
k1 = rotl32(k1, 15);
k1 *= c2;
h1 ^= k1;
};
//----------
// finalization
h1 ^= len;
h1 = fmix32(h1);
return h1;
}
__host__ __device__ __forceinline__ result_type operator()(const Key& key) const {
return this->hash(key);
}
};
template <typename key_type, typename index_type, index_type result>
struct Fix_Hash {
using result_type = index_type;
__forceinline__ __host__ __device__ static index_type hash(const key_type& key) { return result; }
};
template <typename key_type, typename result_type>
struct Mod_Hash {
__forceinline__ __host__ __device__ static result_type hash(const key_type& key) {
return (result_type)key;
}
};
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nv_util.h>
#include <cstdio>
#include <hash_functions.cuh>
#include <limits>
#include "gpu_cache_api.hpp"
#ifdef LIBCUDACXX_VERSION
#include <cuda/std/atomic>
#include <cuda/std/semaphore>
#endif
#define SET_ASSOCIATIVITY 2
#define SLAB_SIZE 32
#define TASK_PER_WARP_TILE_MACRO 1
namespace gpu_cache {
// slab for static slab list
template <typename key_type, int warp_size>
struct static_slab {
key_type slab_[warp_size];
};
// Static slablist(slabset) for GPU Cache
template <int set_associativity, typename key_type, int warp_size>
struct slab_set {
static_slab<key_type, warp_size> set_[set_associativity];
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// GPU Cache
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher = MurmurHash3_32<key_type>,
typename slab_hasher = Mod_Hash<key_type, size_t>>
class gpu_cache : public gpu_cache_api<key_type> {
public:
// Ctor
gpu_cache(const size_t capacity_in_set, const size_t embedding_vec_size);
// Dtor
~gpu_cache();
// Query API, i.e. A single read from the cache
void Query(const key_type* d_keys, const size_t len, float* d_values, uint64_t* d_missing_index,
key_type* d_missing_keys, size_t* d_missing_len, cudaStream_t stream,
const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) override;
// Replace API, i.e. Follow the Query API to update the content of the cache to Most Recent
void Replace(const key_type* d_keys, const size_t len, const float* d_values, cudaStream_t stream,
const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) override;
// Update API, i.e. update the embeddings which exist in the cache
void Update(const key_type* d_keys, const size_t len, const float* d_values, cudaStream_t stream,
const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) override;
// Dump API, i.e. dump some slabsets' keys from the cache
void Dump(key_type* d_keys, size_t* d_dump_counter, const size_t start_set_index,
const size_t end_set_index, cudaStream_t stream) override;
public:
using slabset = slab_set<set_associativity, key_type, warp_size>;
#ifdef LIBCUDACXX_VERSION
using atomic_ref_counter_type = cuda::atomic<ref_counter_type, cuda::thread_scope_device>;
using mutex = cuda::binary_semaphore<cuda::thread_scope_device>;
#endif
private:
static const size_t BLOCK_SIZE_ = 64;
// Cache data
slabset* keys_;
float* vals_;
ref_counter_type* slot_counter_;
// Global counter
#ifdef LIBCUDACXX_VERSION
atomic_ref_counter_type* global_counter_;
#else
ref_counter_type* global_counter_;
#endif
// CUDA device
int dev_;
// Cache capacity
size_t capacity_in_set_;
size_t num_slot_;
// Embedding vector size
size_t embedding_vec_size_;
#ifdef LIBCUDACXX_VERSION
// Array of mutex to protect (sub-)warp-level data structure, each mutex protect 1 slab set
mutex* set_mutex_;
#else
// Array of flag to protect (sub-)warp-level data structure, each flag act as a mutex and protect
// 1 slab set 1 for unlock, 0 for lock
int* set_mutex_;
#endif
};
} // namespace gpu_cache
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime_api.h>
#include <stdexcept>
#include <string>
#define CUDA_CHECK(val) \
{ nv::cuda_check_((val), __FILE__, __LINE__); }
namespace nv {
class CudaException : public std::runtime_error {
public:
CudaException(const std::string& what) : runtime_error(what) {}
};
inline void cuda_check_(cudaError_t val, const char* file, int line) {
if (val != cudaSuccess) {
throw CudaException(std::string(file) + ":" + std::to_string(line) + ": CUDA error " +
std::to_string(val) + ": " + cudaGetErrorString(val));
}
}
class CudaDeviceRestorer {
public:
CudaDeviceRestorer() { CUDA_CHECK(cudaGetDevice(&dev_)); }
~CudaDeviceRestorer() { CUDA_CHECK(cudaSetDevice(dev_)); }
void check_device(int device) const {
if (device != dev_) {
throw std::runtime_error(
std::string(__FILE__) + ":" + std::to_string(__LINE__) +
": Runtime Error: The device id in the context is not consistent with configuration");
}
}
private:
int dev_;
};
inline int get_dev(const void* ptr) {
cudaPointerAttributes attr;
CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
int dev = -1;
#if CUDART_VERSION >= 10000
if (attr.type == cudaMemoryTypeDevice)
#else
if (attr.memoryType == cudaMemoryTypeDevice)
#endif
{
dev = attr.device;
}
return dev;
}
inline void switch_to_dev(const void* ptr) {
int dev = get_dev(ptr);
if (dev >= 0) {
CUDA_CHECK(cudaSetDevice(dev));
}
}
} // namespace nv
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nv_util.h>
#include <hash_functions.cuh>
namespace gpu_cache {
template <typename key_type, typename value_type, unsigned int tile_size = 4,
unsigned int group_size = 16, typename hasher = MurmurHash3_32<key_type>>
class StaticHashTable {
public:
using size_type = uint32_t;
static_assert(sizeof(key_type) <= 8, "sizeof(key_type) cannot be larger than 8 bytes");
static_assert(sizeof(key_type) >= sizeof(size_type),
"sizeof(key_type) cannot be smaller than sizeof(size_type)");
static_assert((group_size & (group_size - 1)) == 0, "group_size must be a power of 2");
static_assert(group_size > 1, "group_size must be larger than 1");
// User can use empty_key as input without affecting correctness,
// since we will handle it inside kernel.
constexpr static key_type empty_key = ~(key_type)0;
constexpr static size_type invalid_slot = ~(size_type)0;
public:
StaticHashTable(size_type capacity, int value_dim = 1, hasher hash = hasher{});
~StaticHashTable();
inline size_type size() const { return size_; }
inline size_type capacity() const { return value_capacity_; }
inline size_type key_capacity() const { return key_capacity_; }
inline size_t memory_usage() const {
size_t keys_bytes = sizeof(key_type) * (key_capacity_ + 1);
size_t indices_bytes = sizeof(size_type) * (key_capacity_ + 1);
size_t values_bytes = sizeof(value_type) * value_capacity_ * value_dim_;
return keys_bytes + indices_bytes + values_bytes;
}
void clear(cudaStream_t stream = 0);
// Note:
// 1. Please make sure the key to be inserted is not duplicated.
// 2. Please make sure the key to be inserted does not exist in the table.
// 3. Please make sure (size() + num_keys) <= capacity().
void insert(const key_type *keys, const value_type *values, size_type num_keys,
cudaStream_t stream = 0);
void lookup(const key_type *keys, value_type *values, int num_keys, value_type default_value = 0,
cudaStream_t stream = 0);
private:
key_type *table_keys_;
size_type *table_indices_;
size_type key_capacity_;
value_type *table_values_;
size_type value_capacity_;
int value_dim_;
size_type size_;
hasher hash_;
};
} // namespace gpu_cache
\ No newline at end of file
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nv_util.h>
#include <cstdio>
#include <limits>
#include <static_hash_table.hpp>
namespace gpu_cache {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename key_type>
class static_table {
public:
// Ctor
static_table(const size_t table_size, const size_t embedding_vec_size,
const float default_value = 0);
// Dtor
~static_table(){};
// Query API, i.e. A single read from the cache
void Query(const key_type* d_keys, const size_t len, float* d_values, cudaStream_t stream);
// Replace API, i.e. Follow the Query API to update the content of the cache to Most Recent
void Init(const key_type* d_keys, const size_t len, const float* d_values, cudaStream_t stream);
void Clear(cudaStream_t stream);
private:
StaticHashTable<key_type, float> static_hash_table_;
// Embedding vector size
size_t embedding_vec_size_;
size_t table_size_;
float default_value_;
};
} // namespace gpu_cache
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nv_util.h>
#include <thread>
#include <unordered_map>
#include <vector>
namespace gpu_cache {
template <typename key_type, typename index_type>
class HashBlock {
public:
key_type* keys;
size_t num_sets;
size_t capacity;
HashBlock(size_t expected_capacity, int set_size, int batch_size);
~HashBlock();
void add(const key_type* new_keys, const size_t num_keys, key_type* missing_keys,
int* num_missing_keys, cudaStream_t stream);
void query(const key_type* query_keys, const size_t num_keys, index_type* output_indices,
key_type* missing_keys, int* missing_positions, int* num_missing_keys,
cudaStream_t stream);
void query(const key_type* query_keys, int* num_keys, index_type* output_indices,
cudaStream_t stream);
void clear(cudaStream_t stream);
private:
int max_set_size_;
int batch_size_;
int* set_sizes_;
};
template <typename vec_type>
class H2HCopy {
public:
H2HCopy(int num_threads) : num_threads_(num_threads), working_(num_threads) {
for (int i = 0; i < num_threads_; i++) {
threads_.emplace_back(
[&](int idx) {
while (!terminate_) {
if (working_[idx].load(std::memory_order_relaxed)) {
working_[idx].store(false, std::memory_order_relaxed);
if (num_keys_ == 0) continue;
size_t num_keys_this_thread = (num_keys_ - 1) / num_threads_ + 1;
size_t begin = idx * num_keys_this_thread;
if (idx == num_threads_ - 1) {
num_keys_this_thread = num_keys_ - num_keys_this_thread * idx;
}
size_t end = begin + num_keys_this_thread;
for (size_t i = begin; i < end; i++) {
size_t idx_vec = get_index_(i);
if (idx_vec == std::numeric_limits<size_t>::max()) {
continue;
}
memcpy(dst_data_ptr_ + i * vec_size_, src_data_ptr_ + idx_vec * vec_size_,
sizeof(vec_type) * vec_size_);
}
num_finished_workers_++;
}
}
std::this_thread::sleep_for(std::chrono::microseconds(1));
},
i);
}
};
void copy(vec_type* dst_data_ptr, vec_type* src_data_ptr, size_t num_keys, int vec_size,
std::function<size_t(size_t)> get_index_func) {
std::lock_guard<std::mutex> guard(submit_mutex_);
dst_data_ptr_ = dst_data_ptr;
src_data_ptr_ = src_data_ptr;
get_index_ = get_index_func;
num_keys_ = num_keys;
vec_size_ = vec_size;
num_finished_workers_.store(0, std::memory_order_acquire);
for (auto& working : working_) {
working.store(true, std::memory_order_relaxed);
}
while (num_finished_workers_ != num_threads_) {
continue;
}
}
~H2HCopy() {
terminate_ = true;
for (auto& t : threads_) {
t.join();
}
}
private:
vec_type* src_data_ptr_;
vec_type* dst_data_ptr_;
std::function<size_t(size_t)> get_index_;
size_t num_keys_;
int vec_size_;
std::mutex submit_mutex_;
const int num_threads_;
std::vector<std::thread> threads_;
std::vector<std::atomic<bool>> working_;
volatile bool terminate_{false};
std::atomic<int> num_finished_workers_{0};
};
template <typename key_type, typename index_type, typename vec_type = float>
class UvmTable {
public:
UvmTable(const size_t device_table_capacity, const size_t host_table_capacity,
const int max_batch_size, const int vec_size,
const vec_type default_value = (vec_type)0);
~UvmTable();
void query(const key_type* d_keys, const int len, vec_type* d_vectors, cudaStream_t stream = 0);
void add(const key_type* h_keys, const vec_type* h_vectors, const size_t len);
void clear(cudaStream_t stream = 0);
private:
static constexpr int num_buffers_ = 2;
key_type* d_keys_buffer_;
vec_type* d_vectors_buffer_;
vec_type* d_vectors_;
index_type* d_output_indices_;
index_type* d_output_host_indices_;
index_type* h_output_host_indices_;
key_type* d_missing_keys_;
int* d_missing_positions_;
int* d_missing_count_;
std::vector<vec_type> h_vectors_;
key_type* h_missing_keys_;
cudaStream_t query_stream_;
cudaEvent_t query_event_;
vec_type* h_cpy_buffers_[num_buffers_];
vec_type* d_cpy_buffers_[num_buffers_];
cudaStream_t cpy_streams_[num_buffers_];
cudaEvent_t cpy_events_[num_buffers_];
std::unordered_map<key_type, index_type> h_final_missing_items_;
int max_batch_size_;
int vec_size_;
size_t num_set_;
size_t num_host_set_;
size_t table_capacity_;
std::vector<vec_type> default_vector_;
HashBlock<key_type, index_type> device_table_;
HashBlock<key_type, index_type> host_table_;
};
} // namespace gpu_cache
\ No newline at end of file
#
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cmake_minimum_required(VERSION 3.8)
file(GLOB gpu_cache_src
nv_gpu_cache.cu
static_table.cu
static_hash_table.cu
uvm_table.cu
)
add_library(gpu_cache SHARED ${gpu_cache_src})
target_compile_features(gpu_cache PUBLIC cxx_std_11)
set_target_properties(gpu_cache PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
set_target_properties(gpu_cache PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
set_target_properties(gpu_cache PROPERTIES CUDA_ARCHITECTURES OFF)
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cooperative_groups.h>
#include <nv_gpu_cache.hpp>
namespace cg = cooperative_groups;
// Overload CUDA atomic for other 64bit unsinged/signed integer type
__forceinline__ __device__ long atomicAdd(long* address, long val) {
return (long)atomicAdd((unsigned long long*)address, (unsigned long long)val);
}
__forceinline__ __device__ long long atomicAdd(long long* address, long long val) {
return (long long)atomicAdd((unsigned long long*)address, (unsigned long long)val);
}
__forceinline__ __device__ unsigned long atomicAdd(unsigned long* address, unsigned long val) {
return (unsigned long)atomicAdd((unsigned long long*)address, (unsigned long long)val);
}
namespace gpu_cache {
#ifdef LIBCUDACXX_VERSION
template <int warp_size>
__forceinline__ __device__ void warp_tile_copy(const size_t lane_idx,
const size_t emb_vec_size_in_float, float* d_dst,
const float* d_src) {
#pragma unroll
for (size_t i = lane_idx; i < emb_vec_size_in_float; i += warp_size) {
d_dst[i] = d_src[i];
}
}
#else
template <int warp_size>
__forceinline__ __device__ void warp_tile_copy(const size_t lane_idx,
const size_t emb_vec_size_in_float,
volatile float* d_dst, volatile float* d_src) {
#pragma unroll
for (size_t i = lane_idx; i < emb_vec_size_in_float; i += warp_size) {
d_dst[i] = d_src[i];
}
}
#endif
#ifdef LIBCUDACXX_VERSION
// Will be called by multiple thread_block_tile((sub-)warp) on the same mutex
// Expect only one thread_block_tile return to execute critical section at any time
template <typename mutex, int warp_size>
__forceinline__ __device__ void warp_lock_mutex(const cg::thread_block_tile<warp_size>& warp_tile,
mutex& set_mutex) {
// The first thread of this (sub-)warp to acquire the lock
if (warp_tile.thread_rank() == 0) {
set_mutex.acquire();
}
warp_tile.sync(); // Synchronize the threads in the (sub-)warp. Execution barrier + memory fence
}
// The (sub-)warp holding the mutex will unlock the mutex after finishing the critical section on a
// set Expect any following (sub-)warp that acquire the mutex can see its modification done in the
// critical section
template <typename mutex, int warp_size>
__forceinline__ __device__ void warp_unlock_mutex(const cg::thread_block_tile<warp_size>& warp_tile,
mutex& set_mutex) {
warp_tile.sync(); // Synchronize the threads in the (sub-)warp. Execution barrier + memory fence
// The first thread of this (sub-)warp to release the lock
if (warp_tile.thread_rank() == 0) {
set_mutex.release();
}
}
#else
// Will be called by multiple thread_block_tile((sub-)warp) on the same mutex
// Expect only one thread_block_tile return to execute critical section at any time
template <int warp_size>
__forceinline__ __device__ void warp_lock_mutex(const cg::thread_block_tile<warp_size>& warp_tile,
volatile int& set_mutex) {
// The first thread of this (sub-)warp to acquire the lock
if (warp_tile.thread_rank() == 0) {
while (0 == atomicCAS((int*)&set_mutex, 1, 0))
;
}
__threadfence();
warp_tile.sync(); // Synchronize the threads in the (sub-)warp. Execution barrier + memory fence
}
// The (sub-)warp holding the mutex will unlock the mutex after finishing the critical section on a
// set Expect any following (sub-)warp that acquire the mutex can see its modification done in the
// critical section
template <int warp_size>
__forceinline__ __device__ void warp_unlock_mutex(const cg::thread_block_tile<warp_size>& warp_tile,
volatile int& set_mutex) {
__threadfence();
warp_tile.sync(); // Synchronize the threads in the (sub-)warp. Execution barrier + memory fence
// The first thread of this (sub-)warp to release the lock
if (warp_tile.thread_rank() == 0) {
atomicExch((int*)&set_mutex, 1);
}
}
#endif
// The (sub-)warp doing all reduction to find the slot with min slot_counter
// The slot with min slot_counter is the LR slot.
template <typename ref_counter_type, int warp_size>
__forceinline__ __device__ void warp_min_reduction(
const cg::thread_block_tile<warp_size>& warp_tile, ref_counter_type& min_slot_counter_val,
size_t& slab_distance, size_t& slot_distance) {
const size_t lane_idx = warp_tile.thread_rank();
slot_distance = lane_idx;
for (size_t i = (warp_tile.size() >> 1); i > 0; i = i >> 1) {
ref_counter_type input_slot_counter_val = warp_tile.shfl_xor(min_slot_counter_val, (int)i);
size_t input_slab_distance = warp_tile.shfl_xor(slab_distance, (int)i);
size_t input_slot_distance = warp_tile.shfl_xor(slot_distance, (int)i);
if (input_slot_counter_val == min_slot_counter_val) {
if (input_slab_distance == slab_distance) {
if (input_slot_distance < slot_distance) {
slot_distance = input_slot_distance;
}
} else if (input_slab_distance < slab_distance) {
slab_distance = input_slab_distance;
slot_distance = input_slot_distance;
}
} else if (input_slot_counter_val < min_slot_counter_val) {
min_slot_counter_val = input_slot_counter_val;
slab_distance = input_slab_distance;
slot_distance = input_slot_distance;
}
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef LIBCUDACXX_VERSION
// Kernel to initialize the GPU cache
// Init every entry of the cache with <unused_key, value> pair
template <typename slabset, typename ref_counter_type, typename atomic_ref_counter_type,
typename key_type, typename mutex>
__global__ void init_cache(slabset* keys, ref_counter_type* slot_counter,
atomic_ref_counter_type* global_counter, const size_t num_slot,
const key_type empty_key, mutex* set_mutex,
const size_t capacity_in_set) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_slot) {
// Set the key of this slot to unused key
// Flatten the cache
key_type* key_slot = (key_type*)keys;
key_slot[idx] = empty_key;
// Clear the counter for this slot
slot_counter[idx] = 0;
}
// First CUDA thread clear the global counter
if (idx == 0) {
new (global_counter) atomic_ref_counter_type(0);
}
// First capacity_in_set CUDA thread initialize mutex
if (idx < capacity_in_set) {
new (set_mutex + idx) mutex(1);
}
}
template <typename atomic_ref_counter_type, typename mutex>
__global__ void destruct_kernel(atomic_ref_counter_type* global_counter, mutex* set_mutex,
const size_t capacity_in_set) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
// First CUDA thread destruct the global_counter
if (idx == 0) {
global_counter->~atomic_ref_counter_type();
}
// First capacity_in_set CUDA thread destruct the set mutex
if (idx < capacity_in_set) {
(set_mutex + idx)->~mutex();
}
}
#else
// Kernel to initialize the GPU cache
// Init every entry of the cache with <unused_key, value> pair
template <typename slabset, typename ref_counter_type, typename key_type>
__global__ void init_cache(slabset* keys, ref_counter_type* slot_counter,
ref_counter_type* global_counter, const size_t num_slot,
const key_type empty_key, int* set_mutex, const size_t capacity_in_set) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_slot) {
// Set the key of this slot to unused key
// Flatten the cache
key_type* key_slot = (key_type*)keys;
key_slot[idx] = empty_key;
// Clear the counter for this slot
slot_counter[idx] = 0;
}
// First CUDA thread clear the global counter
if (idx == 0) {
global_counter[idx] = 0;
}
// First capacity_in_set CUDA thread initialize mutex
if (idx < capacity_in_set) {
set_mutex[idx] = 1;
}
}
#endif
// Kernel to update global counter
// Resolve distance overflow issue as well
#ifdef LIBCUDACXX_VERSION
template <typename atomic_ref_counter_type>
__global__ void update_kernel_overflow_ignore(atomic_ref_counter_type* global_counter,
size_t* d_missing_len) {
// Update global counter
global_counter->fetch_add(1, cuda::std::memory_order_relaxed);
*d_missing_len = 0;
}
#else
template <typename ref_counter_type>
__global__ void update_kernel_overflow_ignore(ref_counter_type* global_counter,
size_t* d_missing_len) {
// Update global counter
atomicAdd(global_counter, 1);
*d_missing_len = 0;
}
#endif
#ifdef LIBCUDACXX_VERSION
// Kernel to read from cache
// Also update locality information for touched slot
template <typename key_type, typename ref_counter_type, typename atomic_ref_counter_type,
typename slabset, typename set_hasher, typename slab_hasher, typename mutex,
key_type empty_key, int set_associativity, int warp_size>
__global__ void get_kernel(const key_type* d_keys, const size_t len, float* d_values,
const size_t embedding_vec_size, uint64_t* d_missing_index,
key_type* d_missing_keys, size_t* d_missing_len,
const atomic_ref_counter_type* global_counter,
ref_counter_type* slot_counter, const size_t capacity_in_set,
const slabset* keys, const float* vals, mutex* set_mutex,
const size_t task_per_warp_tile) {
// Lane(thread) ID within a warp_tile
cg::thread_block_tile<warp_size> warp_tile =
cg::tiled_partition<warp_size>(cg::this_thread_block());
const size_t lane_idx = warp_tile.thread_rank();
// Warp tile global ID
const size_t warp_tile_global_idx =
(blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();
// The index of key for this thread
const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;
// The assigned key for this lane(thread)
key_type key;
// The dst slabset and the dst slab inside this set
size_t src_set;
size_t src_slab;
// The variable that contains the missing key
key_type missing_key;
// The variable that contains the index for the missing key
uint64_t missing_index;
// The counter for counting the missing key in this warp
uint8_t warp_missing_counter = 0;
// Active flag: whether current lane(thread) has unfinished task
bool active = false;
if (lane_idx < task_per_warp_tile) {
if (key_idx < len) {
active = true;
key = d_keys[key_idx];
src_set = set_hasher::hash(key) % capacity_in_set;
src_slab = slab_hasher::hash(key) % set_associativity;
}
}
// Lane participate in warp_tile ballot to produce warp-level work queue
unsigned active_mask = warp_tile.ballot(active);
// The warp-level outer loop: finish all the tasks within the work queue
while (active_mask != 0) {
// Next task in the work quere, start from lower index lane(thread)
int next_lane = __ffs(active_mask) - 1;
// Broadcast the task and the global index to all lane in the warp_tile
key_type next_key = warp_tile.shfl(key, next_lane);
size_t next_idx = warp_tile.shfl(key_idx, next_lane);
size_t next_set = warp_tile.shfl(src_set, next_lane);
size_t next_slab = warp_tile.shfl(src_slab, next_lane);
// Counter to record how many slab have been searched
size_t counter = 0;
// Working queue before task started
const unsigned old_active_mask = active_mask;
// Lock the slabset before operating the slabset
warp_lock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);
// The warp-level inner loop: finish a single task in the work queue
while (active_mask == old_active_mask) {
// When all the slabs inside a slabset have been searched, mark missing task, task is
// completed
if (counter >= set_associativity) {
if (lane_idx == warp_missing_counter) {
missing_key = next_key;
missing_index = next_idx;
}
if (lane_idx == (size_t)next_lane) {
active = false;
}
warp_missing_counter++;
active_mask = warp_tile.ballot(active);
break;
}
// The warp_tile read out the slab
key_type read_key = keys[next_set].set_[next_slab].slab_[lane_idx];
// Compare the slab data with the target key
int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;
// If found, mark hit task, copy the founded data, the task is completed
if (found_lane >= 0) {
size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;
if (lane_idx == (size_t)next_lane) {
slot_counter[found_offset] = global_counter->load(cuda::std::memory_order_relaxed);
active = false;
}
warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,
d_values + next_idx * embedding_vec_size,
vals + found_offset * embedding_vec_size);
active_mask = warp_tile.ballot(active);
break;
}
// Compare the slab data with empty key, if found empty key, mark missing task, task is
// completed
if (warp_tile.ballot(read_key == empty_key) != 0) {
if (lane_idx == warp_missing_counter) {
missing_key = next_key;
missing_index = next_idx;
}
if (lane_idx == (size_t)next_lane) {
active = false;
}
warp_missing_counter++;
active_mask = warp_tile.ballot(active);
break;
}
// Not found in this slab, the task is not completed, goto searching next slab
counter++;
next_slab = (next_slab + 1) % set_associativity;
}
// Unlock the slabset after operating the slabset
warp_unlock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);
}
// After warp_tile complete the working queue, save the result for output
// First thread of the warp_tile accumulate the missing length to global variable
size_t warp_position;
if (lane_idx == 0) {
warp_position = atomicAdd(d_missing_len, (size_t)warp_missing_counter);
}
warp_position = warp_tile.shfl(warp_position, 0);
if (lane_idx < warp_missing_counter) {
d_missing_keys[warp_position + lane_idx] = missing_key;
d_missing_index[warp_position + lane_idx] = missing_index;
}
}
#else
// Kernel to read from cache
// Also update locality information for touched slot
template <typename key_type, typename ref_counter_type, typename slabset, typename set_hasher,
typename slab_hasher, key_type empty_key, int set_associativity, int warp_size>
__global__ void get_kernel(const key_type* d_keys, const size_t len, float* d_values,
const size_t embedding_vec_size, uint64_t* d_missing_index,
key_type* d_missing_keys, size_t* d_missing_len,
ref_counter_type* global_counter,
volatile ref_counter_type* slot_counter, const size_t capacity_in_set,
volatile slabset* keys, volatile float* vals, volatile int* set_mutex,
const size_t task_per_warp_tile) {
// Lane(thread) ID within a warp_tile
cg::thread_block_tile<warp_size> warp_tile =
cg::tiled_partition<warp_size>(cg::this_thread_block());
const size_t lane_idx = warp_tile.thread_rank();
// Warp tile global ID
const size_t warp_tile_global_idx =
(blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();
// The index of key for this thread
const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;
// The assigned key for this lane(thread)
key_type key;
// The dst slabset and the dst slab inside this set
size_t src_set;
size_t src_slab;
// The variable that contains the missing key
key_type missing_key;
// The variable that contains the index for the missing key
uint64_t missing_index;
// The counter for counting the missing key in this warp
uint8_t warp_missing_counter = 0;
// Active flag: whether current lane(thread) has unfinished task
bool active = false;
if (lane_idx < task_per_warp_tile) {
if (key_idx < len) {
active = true;
key = d_keys[key_idx];
src_set = set_hasher::hash(key) % capacity_in_set;
src_slab = slab_hasher::hash(key) % set_associativity;
}
}
// Lane participate in warp_tile ballot to produce warp-level work queue
unsigned active_mask = warp_tile.ballot(active);
// The warp-level outer loop: finish all the tasks within the work queue
while (active_mask != 0) {
// Next task in the work quere, start from lower index lane(thread)
int next_lane = __ffs(active_mask) - 1;
// Broadcast the task and the global index to all lane in the warp_tile
key_type next_key = warp_tile.shfl(key, next_lane);
size_t next_idx = warp_tile.shfl(key_idx, next_lane);
size_t next_set = warp_tile.shfl(src_set, next_lane);
size_t next_slab = warp_tile.shfl(src_slab, next_lane);
// Counter to record how many slab have been searched
size_t counter = 0;
// Working queue before task started
const unsigned old_active_mask = active_mask;
// Lock the slabset before operating the slabset
warp_lock_mutex<warp_size>(warp_tile, set_mutex[next_set]);
// The warp-level inner loop: finish a single task in the work queue
while (active_mask == old_active_mask) {
// When all the slabs inside a slabset have been searched, mark missing task, task is
// completed
if (counter >= set_associativity) {
if (lane_idx == warp_missing_counter) {
missing_key = next_key;
missing_index = next_idx;
}
if (lane_idx == (size_t)next_lane) {
active = false;
}
warp_missing_counter++;
active_mask = warp_tile.ballot(active);
break;
}
// The warp_tile read out the slab
key_type read_key = ((volatile key_type*)(keys[next_set].set_[next_slab].slab_))[lane_idx];
// Compare the slab data with the target key
int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;
// If found, mark hit task, copy the founded data, the task is completed
if (found_lane >= 0) {
size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;
if (lane_idx == (size_t)next_lane) {
slot_counter[found_offset] = atomicAdd(global_counter, 0);
active = false;
}
warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,
(volatile float*)(d_values + next_idx * embedding_vec_size),
(volatile float*)(vals + found_offset * embedding_vec_size));
active_mask = warp_tile.ballot(active);
break;
}
// Compare the slab data with empty key, if found empty key, mark missing task, task is
// completed
if (warp_tile.ballot(read_key == empty_key) != 0) {
if (lane_idx == warp_missing_counter) {
missing_key = next_key;
missing_index = next_idx;
}
if (lane_idx == (size_t)next_lane) {
active = false;
}
warp_missing_counter++;
active_mask = warp_tile.ballot(active);
break;
}
// Not found in this slab, the task is not completed, goto searching next slab
counter++;
next_slab = (next_slab + 1) % set_associativity;
}
// Unlock the slabset after operating the slabset
warp_unlock_mutex<warp_size>(warp_tile, set_mutex[next_set]);
}
// After warp_tile complete the working queue, save the result for output
// First thread of the warp_tile accumulate the missing length to global variable
size_t warp_position;
if (lane_idx == 0) {
warp_position = atomicAdd(d_missing_len, (size_t)warp_missing_counter);
}
warp_position = warp_tile.shfl(warp_position, 0);
if (lane_idx < warp_missing_counter) {
d_missing_keys[warp_position + lane_idx] = missing_key;
d_missing_index[warp_position + lane_idx] = missing_index;
}
}
#endif
#ifdef LIBCUDACXX_VERSION
// Kernel to insert or replace the <k,v> pairs into the cache
template <typename key_type, typename slabset, typename ref_counter_type, typename mutex,
typename atomic_ref_counter_type, typename set_hasher, typename slab_hasher,
key_type empty_key, int set_associativity, int warp_size,
ref_counter_type max_ref_counter_type = std::numeric_limits<ref_counter_type>::max(),
size_t max_slab_distance = std::numeric_limits<size_t>::max()>
__global__ void insert_replace_kernel(const key_type* d_keys, const float* d_values,
const size_t embedding_vec_size, const size_t len,
slabset* keys, float* vals, ref_counter_type* slot_counter,
mutex* set_mutex,
const atomic_ref_counter_type* global_counter,
const size_t capacity_in_set,
const size_t task_per_warp_tile) {
// Lane(thread) ID within a warp_tile
cg::thread_block_tile<warp_size> warp_tile =
cg::tiled_partition<warp_size>(cg::this_thread_block());
const size_t lane_idx = warp_tile.thread_rank();
// Warp tile global ID
const size_t warp_tile_global_idx =
(blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();
// The index of key for this thread
const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;
// The assigned key for this lane(thread)
key_type key;
// The dst slabset and the dst slab inside this set
size_t src_set;
size_t src_slab;
// Active flag: whether current lane(thread) has unfinished task
bool active = false;
if (lane_idx < task_per_warp_tile) {
if (key_idx < len) {
active = true;
key = d_keys[key_idx];
src_set = set_hasher::hash(key) % capacity_in_set;
src_slab = slab_hasher::hash(key) % set_associativity;
}
}
// Lane participate in warp_tile ballot to produce warp-level work queue
unsigned active_mask = warp_tile.ballot(active);
// The warp-level outer loop: finish all the tasks within the work queue
while (active_mask != 0) {
// Next task in the work quere, start from lower index lane(thread)
int next_lane = __ffs(active_mask) - 1;
// Broadcast the task, the global index and the src slabset and slab to all lane in a warp_tile
key_type next_key = warp_tile.shfl(key, next_lane);
size_t next_idx = warp_tile.shfl(key_idx, next_lane);
size_t next_set = warp_tile.shfl(src_set, next_lane);
size_t next_slab = warp_tile.shfl(src_slab, next_lane);
size_t first_slab = next_slab;
// Counter to record how many slab have been searched
size_t counter = 0;
// Variable to keep the min slot counter during the probing
ref_counter_type min_slot_counter_val = max_ref_counter_type;
// Variable to keep the slab distance for slot with min counter
size_t slab_distance = max_slab_distance;
// Variable to keep the slot distance for slot with min counter within the slab
size_t slot_distance;
// Working queue before task started
const unsigned old_active_mask = active_mask;
// Lock the slabset before operating the slabset
warp_lock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);
// The warp-level inner loop: finish a single task in the work queue
while (active_mask == old_active_mask) {
// When all the slabs inside a slabset have been searched
// and no empty slots or target slots are found. Replace with LRU
if (counter >= set_associativity) {
// (sub)Warp all-reduction, the reduction result store in all threads
warp_min_reduction<ref_counter_type, warp_size>(warp_tile, min_slot_counter_val,
slab_distance, slot_distance);
// Calculate the position of LR slot
size_t target_slab = (first_slab + slab_distance) % set_associativity;
size_t slot_index =
(next_set * set_associativity + target_slab) * warp_size + slot_distance;
// Replace the LR slot
if (lane_idx == (size_t)next_lane) {
keys[next_set].set_[target_slab].slab_[slot_distance] = key;
slot_counter[slot_index] = global_counter->load(cuda::std::memory_order_relaxed);
}
warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,
vals + slot_index * embedding_vec_size,
d_values + next_idx * embedding_vec_size);
// Replace complete, mark this task completed
if (lane_idx == (size_t)next_lane) {
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// The warp_tile read out the slab
key_type read_key = keys[next_set].set_[next_slab].slab_[lane_idx];
// Compare the slab data with the target key
int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;
// If found target key, the insertion/replace is no longer needed.
// Refresh the slot, the task is completed
if (found_lane >= 0) {
size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;
if (lane_idx == (size_t)next_lane) {
slot_counter[found_offset] = global_counter->load(cuda::std::memory_order_relaxed);
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// Compare the slab data with empty key.
// If found empty key, do insertion,the task is complete
found_lane = __ffs(warp_tile.ballot(read_key == empty_key)) - 1;
if (found_lane >= 0) {
size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;
if (lane_idx == (size_t)next_lane) {
keys[next_set].set_[next_slab].slab_[found_lane] = key;
slot_counter[found_offset] = global_counter->load(cuda::std::memory_order_relaxed);
}
warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,
vals + found_offset * embedding_vec_size,
d_values + next_idx * embedding_vec_size);
if (lane_idx == (size_t)next_lane) {
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// If no target or unused slot found in this slab,
// Refresh LR info, continue probing
ref_counter_type read_slot_counter =
slot_counter[(next_set * set_associativity + next_slab) * warp_size + lane_idx];
if (read_slot_counter < min_slot_counter_val) {
min_slot_counter_val = read_slot_counter;
slab_distance = counter;
}
counter++;
next_slab = (next_slab + 1) % set_associativity;
}
// Unlock the slabset after operating the slabset
warp_unlock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);
}
}
#else
// Kernel to insert or replace the <k,v> pairs into the cache
template <typename key_type, typename slabset, typename ref_counter_type, typename set_hasher,
typename slab_hasher, key_type empty_key, int set_associativity, int warp_size,
ref_counter_type max_ref_counter_type = std::numeric_limits<ref_counter_type>::max(),
size_t max_slab_distance = std::numeric_limits<size_t>::max()>
__global__ void insert_replace_kernel(const key_type* d_keys, const float* d_values,
const size_t embedding_vec_size, const size_t len,
volatile slabset* keys, volatile float* vals,
volatile ref_counter_type* slot_counter,
volatile int* set_mutex, ref_counter_type* global_counter,
const size_t capacity_in_set,
const size_t task_per_warp_tile) {
// Lane(thread) ID within a warp_tile
cg::thread_block_tile<warp_size> warp_tile =
cg::tiled_partition<warp_size>(cg::this_thread_block());
const size_t lane_idx = warp_tile.thread_rank();
// Warp tile global ID
const size_t warp_tile_global_idx =
(blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();
// The index of key for this thread
const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;
// The assigned key for this lane(thread)
key_type key;
// The dst slabset and the dst slab inside this set
size_t src_set;
size_t src_slab;
// Active flag: whether current lane(thread) has unfinished task
bool active = false;
if (lane_idx < task_per_warp_tile) {
if (key_idx < len) {
active = true;
key = d_keys[key_idx];
src_set = set_hasher::hash(key) % capacity_in_set;
src_slab = slab_hasher::hash(key) % set_associativity;
}
}
// Lane participate in warp_tile ballot to produce warp-level work queue
unsigned active_mask = warp_tile.ballot(active);
// The warp-level outer loop: finish all the tasks within the work queue
while (active_mask != 0) {
// Next task in the work quere, start from lower index lane(thread)
int next_lane = __ffs(active_mask) - 1;
// Broadcast the task, the global index and the src slabset and slab to all lane in a warp_tile
key_type next_key = warp_tile.shfl(key, next_lane);
size_t next_idx = warp_tile.shfl(key_idx, next_lane);
size_t next_set = warp_tile.shfl(src_set, next_lane);
size_t next_slab = warp_tile.shfl(src_slab, next_lane);
size_t first_slab = next_slab;
// Counter to record how many slab have been searched
size_t counter = 0;
// Variable to keep the min slot counter during the probing
ref_counter_type min_slot_counter_val = max_ref_counter_type;
// Variable to keep the slab distance for slot with min counter
size_t slab_distance = max_slab_distance;
// Variable to keep the slot distance for slot with min counter within the slab
size_t slot_distance;
// Working queue before task started
const unsigned old_active_mask = active_mask;
// Lock the slabset before operating the slabset
warp_lock_mutex<warp_size>(warp_tile, set_mutex[next_set]);
// The warp-level inner loop: finish a single task in the work queue
while (active_mask == old_active_mask) {
// When all the slabs inside a slabset have been searched
// and no empty slots or target slots are found. Replace with LRU
if (counter >= set_associativity) {
// (sub)Warp all-reduction, the reduction result store in all threads
warp_min_reduction<ref_counter_type, warp_size>(warp_tile, min_slot_counter_val,
slab_distance, slot_distance);
// Calculate the position of LR slot
size_t target_slab = (first_slab + slab_distance) % set_associativity;
size_t slot_index =
(next_set * set_associativity + target_slab) * warp_size + slot_distance;
// Replace the LR slot
if (lane_idx == (size_t)next_lane) {
((volatile key_type*)(keys[next_set].set_[target_slab].slab_))[slot_distance] = key;
slot_counter[slot_index] = atomicAdd(global_counter, 0);
}
warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,
(volatile float*)(vals + slot_index * embedding_vec_size),
(volatile float*)(d_values + next_idx * embedding_vec_size));
// Replace complete, mark this task completed
if (lane_idx == (size_t)next_lane) {
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// The warp_tile read out the slab
key_type read_key = ((volatile key_type*)(keys[next_set].set_[next_slab].slab_))[lane_idx];
// Compare the slab data with the target key
int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;
// If found target key, the insertion/replace is no longer needed.
// Refresh the slot, the task is completed
if (found_lane >= 0) {
size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;
if (lane_idx == (size_t)next_lane) {
slot_counter[found_offset] = atomicAdd(global_counter, 0);
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// Compare the slab data with empty key.
// If found empty key, do insertion,the task is complete
found_lane = __ffs(warp_tile.ballot(read_key == empty_key)) - 1;
if (found_lane >= 0) {
size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;
if (lane_idx == (size_t)next_lane) {
((volatile key_type*)(keys[next_set].set_[next_slab].slab_))[found_lane] = key;
slot_counter[found_offset] = atomicAdd(global_counter, 0);
}
warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,
(volatile float*)(vals + found_offset * embedding_vec_size),
(volatile float*)(d_values + next_idx * embedding_vec_size));
if (lane_idx == (size_t)next_lane) {
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// If no target or unused slot found in this slab,
// Refresh LR info, continue probing
ref_counter_type read_slot_counter =
slot_counter[(next_set * set_associativity + next_slab) * warp_size + lane_idx];
if (read_slot_counter < min_slot_counter_val) {
min_slot_counter_val = read_slot_counter;
slab_distance = counter;
}
counter++;
next_slab = (next_slab + 1) % set_associativity;
}
// Unlock the slabset after operating the slabset
warp_unlock_mutex<warp_size>(warp_tile, set_mutex[next_set]);
}
}
#endif
#ifdef LIBCUDACXX_VERSION
// Kernel to update the existing keys in the cache
// Will not change the locality information
template <typename key_type, typename slabset, typename set_hasher, typename slab_hasher,
typename mutex, key_type empty_key, int set_associativity, int warp_size>
__global__ void update_kernel(const key_type* d_keys, const size_t len, const float* d_values,
const size_t embedding_vec_size, const size_t capacity_in_set,
const slabset* keys, float* vals, mutex* set_mutex,
const size_t task_per_warp_tile) {
// Lane(thread) ID within a warp_tile
cg::thread_block_tile<warp_size> warp_tile =
cg::tiled_partition<warp_size>(cg::this_thread_block());
const size_t lane_idx = warp_tile.thread_rank();
// Warp tile global ID
const size_t warp_tile_global_idx =
(blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();
// The index of key for this thread
const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;
// The assigned key for this lane(thread)
key_type key;
// The dst slabset and the dst slab inside this set
size_t src_set;
size_t src_slab;
// Active flag: whether current lane(thread) has unfinished task
bool active = false;
if (lane_idx < task_per_warp_tile) {
if (key_idx < len) {
active = true;
key = d_keys[key_idx];
src_set = set_hasher::hash(key) % capacity_in_set;
src_slab = slab_hasher::hash(key) % set_associativity;
}
}
// Lane participate in warp_tile ballot to produce warp-level work queue
unsigned active_mask = warp_tile.ballot(active);
// The warp-level outer loop: finish all the tasks within the work queue
while (active_mask != 0) {
// Next task in the work quere, start from lower index lane(thread)
int next_lane = __ffs(active_mask) - 1;
// Broadcast the task and the global index to all lane in the warp_tile
key_type next_key = warp_tile.shfl(key, next_lane);
size_t next_idx = warp_tile.shfl(key_idx, next_lane);
size_t next_set = warp_tile.shfl(src_set, next_lane);
size_t next_slab = warp_tile.shfl(src_slab, next_lane);
// Counter to record how many slab have been searched
size_t counter = 0;
// Working queue before task started
const unsigned old_active_mask = active_mask;
// Lock the slabset before operating the slabset
warp_lock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);
// The warp-level inner loop: finish a single task in the work queue
while (active_mask == old_active_mask) {
// When all the slabs inside a slabset have been searched, mark missing task, do nothing, task
// complete
if (counter >= set_associativity) {
if (lane_idx == (size_t)next_lane) {
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// The warp_tile read out the slab
key_type read_key = keys[next_set].set_[next_slab].slab_[lane_idx];
// Compare the slab data with the target key
int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;
// If found, mark hit task, update the value, the task is completed
if (found_lane >= 0) {
size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;
if (lane_idx == (size_t)next_lane) {
active = false;
}
warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,
vals + found_offset * embedding_vec_size,
d_values + next_idx * embedding_vec_size);
active_mask = warp_tile.ballot(active);
break;
}
// Compare the slab data with empty key, if found empty key, mark missing task, do nothing,
// task is completed
if (warp_tile.ballot(read_key == empty_key) != 0) {
if (lane_idx == (size_t)next_lane) {
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// Not found in this slab, the task is not completed, goto searching next slab
counter++;
next_slab = (next_slab + 1) % set_associativity;
}
// Unlock the slabset after operating the slabset
warp_unlock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);
}
}
#else
// Kernel to update the existing keys in the cache
// Will not change the locality information
template <typename key_type, typename slabset, typename set_hasher, typename slab_hasher,
key_type empty_key, int set_associativity, int warp_size>
__global__ void update_kernel(const key_type* d_keys, const size_t len, const float* d_values,
const size_t embedding_vec_size, const size_t capacity_in_set,
volatile slabset* keys, volatile float* vals, volatile int* set_mutex,
const size_t task_per_warp_tile) {
// Lane(thread) ID within a warp_tile
cg::thread_block_tile<warp_size> warp_tile =
cg::tiled_partition<warp_size>(cg::this_thread_block());
const size_t lane_idx = warp_tile.thread_rank();
// Warp tile global ID
const size_t warp_tile_global_idx =
(blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();
// The index of key for this thread
const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;
// The assigned key for this lane(thread)
key_type key;
// The dst slabset and the dst slab inside this set
size_t src_set;
size_t src_slab;
// Active flag: whether current lane(thread) has unfinished task
bool active = false;
if (lane_idx < task_per_warp_tile) {
if (key_idx < len) {
active = true;
key = d_keys[key_idx];
src_set = set_hasher::hash(key) % capacity_in_set;
src_slab = slab_hasher::hash(key) % set_associativity;
}
}
// Lane participate in warp_tile ballot to produce warp-level work queue
unsigned active_mask = warp_tile.ballot(active);
// The warp-level outer loop: finish all the tasks within the work queue
while (active_mask != 0) {
// Next task in the work quere, start from lower index lane(thread)
int next_lane = __ffs(active_mask) - 1;
// Broadcast the task and the global index to all lane in the warp_tile
key_type next_key = warp_tile.shfl(key, next_lane);
size_t next_idx = warp_tile.shfl(key_idx, next_lane);
size_t next_set = warp_tile.shfl(src_set, next_lane);
size_t next_slab = warp_tile.shfl(src_slab, next_lane);
// Counter to record how many slab have been searched
size_t counter = 0;
// Working queue before task started
const unsigned old_active_mask = active_mask;
// Lock the slabset before operating the slabset
warp_lock_mutex<warp_size>(warp_tile, set_mutex[next_set]);
// The warp-level inner loop: finish a single task in the work queue
while (active_mask == old_active_mask) {
// When all the slabs inside a slabset have been searched, mark missing task, do nothing, task
// complete
if (counter >= set_associativity) {
if (lane_idx == (size_t)next_lane) {
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// The warp_tile read out the slab
key_type read_key = ((volatile key_type*)(keys[next_set].set_[next_slab].slab_))[lane_idx];
// Compare the slab data with the target key
int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;
// If found, mark hit task, update the value, the task is completed
if (found_lane >= 0) {
size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;
if (lane_idx == (size_t)next_lane) {
active = false;
}
warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,
(volatile float*)(vals + found_offset * embedding_vec_size),
(volatile float*)(d_values + next_idx * embedding_vec_size));
active_mask = warp_tile.ballot(active);
break;
}
// Compare the slab data with empty key, if found empty key, mark missing task, do nothing,
// task is completed
if (warp_tile.ballot(read_key == empty_key) != 0) {
if (lane_idx == (size_t)next_lane) {
active = false;
}
active_mask = warp_tile.ballot(active);
break;
}
// Not found in this slab, the task is not completed, goto searching next slab
counter++;
next_slab = (next_slab + 1) % set_associativity;
}
// Unlock the slabset after operating the slabset
warp_unlock_mutex<warp_size>(warp_tile, set_mutex[next_set]);
}
}
#endif
#ifdef LIBCUDACXX_VERSION
template <typename key_type, typename slabset, typename mutex, key_type empty_key,
int set_associativity, int warp_size>
__global__ void dump_kernel(key_type* d_keys, size_t* d_dump_counter, const slabset* keys,
mutex* set_mutex, const size_t start_set_index,
const size_t end_set_index) {
// Block-level counter used by all warp tiles within a block
__shared__ uint32_t block_acc;
// Initialize block-level counter
if (threadIdx.x == 0) {
block_acc = 0;
}
__syncthreads();
// Lane(thread) ID within a warp tile
cg::thread_block_tile<warp_size> warp_tile =
cg::tiled_partition<warp_size>(cg::this_thread_block());
const size_t lane_idx = warp_tile.thread_rank();
// Warp tile target slabset id
const size_t set_idx =
((blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank()) + start_set_index;
// Keys dump from cache
key_type read_key[set_associativity];
// Lane(thread) offset for storing each key
uint32_t thread_key_offset[set_associativity];
// Warp offset for storing each key
uint32_t warp_key_offset;
// Block offset for storing each key
__shared__ size_t block_key_offset;
// Warp tile dump target slabset
if (set_idx < end_set_index) {
// Lock the slabset before operating the slabset
warp_lock_mutex<mutex, warp_size>(warp_tile, set_mutex[set_idx]);
// The warp tile read out the slabset
for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {
// The warp tile read out a slab
read_key[slab_id] = keys[set_idx].set_[slab_id].slab_[lane_idx];
}
// Finish dumping the slabset, unlock the slabset
warp_unlock_mutex<mutex, warp_size>(warp_tile, set_mutex[set_idx]);
// Each lane(thread) within the warp tile calculate the offset to store its keys
uint32_t warp_tile_total_keys = 0;
for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {
unsigned valid_mask = warp_tile.ballot(read_key[slab_id] != empty_key);
thread_key_offset[slab_id] =
__popc(valid_mask & ((1U << lane_idx) - 1U)) + warp_tile_total_keys;
warp_tile_total_keys = warp_tile_total_keys + __popc(valid_mask);
}
// Each warp tile request a unique place from the block-level counter
if (lane_idx == 0) {
warp_key_offset = atomicAdd(&block_acc, warp_tile_total_keys);
}
warp_key_offset = warp_tile.shfl(warp_key_offset, 0);
}
// Each block request a unique place in global memory output buffer
__syncthreads();
if (threadIdx.x == 0) {
block_key_offset = atomicAdd(d_dump_counter, (size_t)block_acc);
}
__syncthreads();
// Warp tile store the (non-empty)keys back to output buffer
if (set_idx < end_set_index) {
for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {
if (read_key[slab_id] != empty_key) {
d_keys[block_key_offset + warp_key_offset + thread_key_offset[slab_id]] = read_key[slab_id];
}
}
}
}
#else
template <typename key_type, typename slabset, key_type empty_key, int set_associativity,
int warp_size>
__global__ void dump_kernel(key_type* d_keys, size_t* d_dump_counter, volatile slabset* keys,
volatile int* set_mutex, const size_t start_set_index,
const size_t end_set_index) {
// Block-level counter used by all warp tiles within a block
__shared__ uint32_t block_acc;
// Initialize block-level counter
if (threadIdx.x == 0) {
block_acc = 0;
}
__syncthreads();
// Lane(thread) ID within a warp tile
cg::thread_block_tile<warp_size> warp_tile =
cg::tiled_partition<warp_size>(cg::this_thread_block());
const size_t lane_idx = warp_tile.thread_rank();
// Warp tile target slabset id
const size_t set_idx =
((blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank()) + start_set_index;
// Keys dump from cache
key_type read_key[set_associativity];
// Lane(thread) offset for storing each key
uint32_t thread_key_offset[set_associativity];
// Warp offset for storing each key
uint32_t warp_key_offset;
// Block offset for storing each key
__shared__ size_t block_key_offset;
// Warp tile dump target slabset
if (set_idx < end_set_index) {
// Lock the slabset before operating the slabset
warp_lock_mutex<warp_size>(warp_tile, set_mutex[set_idx]);
// The warp tile read out the slabset
for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {
// The warp tile read out a slab
read_key[slab_id] = ((volatile key_type*)(keys[set_idx].set_[slab_id].slab_))[lane_idx];
}
// Finish dumping the slabset, unlock the slabset
warp_unlock_mutex<warp_size>(warp_tile, set_mutex[set_idx]);
// Each lane(thread) within the warp tile calculate the offset to store its keys
uint32_t warp_tile_total_keys = 0;
for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {
unsigned valid_mask = warp_tile.ballot(read_key[slab_id] != empty_key);
thread_key_offset[slab_id] =
__popc(valid_mask & ((1U << lane_idx) - 1U)) + warp_tile_total_keys;
warp_tile_total_keys = warp_tile_total_keys + __popc(valid_mask);
}
// Each warp tile request a unique place from the block-level counter
if (lane_idx == 0) {
warp_key_offset = atomicAdd(&block_acc, warp_tile_total_keys);
}
warp_key_offset = warp_tile.shfl(warp_key_offset, 0);
}
// Each block request a unique place in global memory output buffer
__syncthreads();
if (threadIdx.x == 0) {
block_key_offset = atomicAdd(d_dump_counter, (size_t)block_acc);
}
__syncthreads();
// Warp tile store the (non-empty)keys back to output buffer
if (set_idx < end_set_index) {
for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {
if (read_key[slab_id] != empty_key) {
d_keys[block_key_offset + warp_key_offset + thread_key_offset[slab_id]] = read_key[slab_id];
}
}
}
}
#endif
///////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef LIBCUDACXX_VERSION
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::gpu_cache(const size_t capacity_in_set, const size_t embedding_vec_size)
: capacity_in_set_(capacity_in_set), embedding_vec_size_(embedding_vec_size) {
// Check parameter
if (capacity_in_set_ == 0) {
printf("Error: Invalid value for capacity_in_set.\n");
return;
}
if (embedding_vec_size_ == 0) {
printf("Error: Invalid value for embedding_vec_size.\n");
return;
}
if (set_associativity <= 0) {
printf("Error: Invalid value for set_associativity.\n");
return;
}
if (warp_size != 1 && warp_size != 2 && warp_size != 4 && warp_size != 8 && warp_size != 16 &&
warp_size != 32) {
printf("Error: Invalid value for warp_size.\n");
return;
}
// Get the current CUDA dev
CUDA_CHECK(cudaGetDevice(&dev_));
// Calculate # of slot
num_slot_ = capacity_in_set_ * set_associativity * warp_size;
// Allocate GPU memory for cache
CUDA_CHECK(cudaMalloc((void**)&keys_, sizeof(slabset) * capacity_in_set_));
CUDA_CHECK(cudaMalloc((void**)&vals_, sizeof(float) * embedding_vec_size_ * num_slot_));
CUDA_CHECK(cudaMalloc((void**)&slot_counter_, sizeof(ref_counter_type) * num_slot_));
CUDA_CHECK(cudaMalloc((void**)&global_counter_, sizeof(atomic_ref_counter_type)));
// Allocate GPU memory for set mutex
CUDA_CHECK(cudaMalloc((void**)&set_mutex_, sizeof(mutex) * capacity_in_set_));
// Initialize the cache, set all entry to unused <K,V>
init_cache<<<((num_slot_ - 1) / BLOCK_SIZE_) + 1, BLOCK_SIZE_>>>(
keys_, slot_counter_, global_counter_, num_slot_, empty_key, set_mutex_, capacity_in_set_);
// Wait for initialization to finish
CUDA_CHECK(cudaStreamSynchronize(0));
CUDA_CHECK(cudaGetLastError());
}
#else
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::gpu_cache(const size_t capacity_in_set, const size_t embedding_vec_size)
: capacity_in_set_(capacity_in_set), embedding_vec_size_(embedding_vec_size) {
// Check parameter
if (capacity_in_set_ == 0) {
printf("Error: Invalid value for capacity_in_set.\n");
return;
}
if (embedding_vec_size_ == 0) {
printf("Error: Invalid value for embedding_vec_size.\n");
return;
}
if (set_associativity <= 0) {
printf("Error: Invalid value for set_associativity.\n");
return;
}
if (warp_size != 1 && warp_size != 2 && warp_size != 4 && warp_size != 8 && warp_size != 16 &&
warp_size != 32) {
printf("Error: Invalid value for warp_size.\n");
return;
}
// Get the current CUDA dev
CUDA_CHECK(cudaGetDevice(&dev_));
// Calculate # of slot
num_slot_ = capacity_in_set_ * set_associativity * warp_size;
// Allocate GPU memory for cache
CUDA_CHECK(cudaMalloc((void**)&keys_, sizeof(slabset) * capacity_in_set_));
CUDA_CHECK(cudaMalloc((void**)&vals_, sizeof(float) * embedding_vec_size_ * num_slot_));
CUDA_CHECK(cudaMalloc((void**)&slot_counter_, sizeof(ref_counter_type) * num_slot_));
CUDA_CHECK(cudaMalloc((void**)&global_counter_, sizeof(ref_counter_type)));
// Allocate GPU memory for set mutex
CUDA_CHECK(cudaMalloc((void**)&set_mutex_, sizeof(int) * capacity_in_set_));
// Initialize the cache, set all entry to unused <K,V>
init_cache<<<((num_slot_ - 1) / BLOCK_SIZE_) + 1, BLOCK_SIZE_>>>(
keys_, slot_counter_, global_counter_, num_slot_, empty_key, set_mutex_, capacity_in_set_);
// Wait for initialization to finish
CUDA_CHECK(cudaStreamSynchronize(0));
CUDA_CHECK(cudaGetLastError());
}
#endif
#ifdef LIBCUDACXX_VERSION
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::~gpu_cache() {
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Destruct CUDA std object
destruct_kernel<<<((capacity_in_set_ - 1) / BLOCK_SIZE_) + 1, BLOCK_SIZE_>>>(
global_counter_, set_mutex_, capacity_in_set_);
// Wait for destruction to finish
CUDA_CHECK(cudaStreamSynchronize(0));
// Free GPU memory for cache
CUDA_CHECK(cudaFree(keys_));
CUDA_CHECK(cudaFree(vals_));
CUDA_CHECK(cudaFree(slot_counter_));
CUDA_CHECK(cudaFree(global_counter_));
// Free GPU memory for set mutex
CUDA_CHECK(cudaFree(set_mutex_));
}
#else
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::~gpu_cache() noexcept(false) {
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Free GPU memory for cache
CUDA_CHECK(cudaFree(keys_));
CUDA_CHECK(cudaFree(vals_));
CUDA_CHECK(cudaFree(slot_counter_));
CUDA_CHECK(cudaFree(global_counter_));
// Free GPU memory for set mutex
CUDA_CHECK(cudaFree(set_mutex_));
}
#endif
#ifdef LIBCUDACXX_VERSION
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
void gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::Query(const key_type* d_keys, const size_t len, float* d_values,
uint64_t* d_missing_index, key_type* d_missing_keys,
size_t* d_missing_len, cudaStream_t stream,
const size_t task_per_warp_tile) {
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Check if it is a valid query
if (len == 0) {
// Set the d_missing_len to 0 before return
CUDA_CHECK(cudaMemsetAsync(d_missing_len, 0, sizeof(size_t), stream));
return;
}
// Update the global counter as user perform a new(most recent) read operation to the cache
// Resolve distance overflow issue as well.
update_kernel_overflow_ignore<atomic_ref_counter_type>
<<<1, 1, 0, stream>>>(global_counter_, d_missing_len);
// Read from the cache
// Touch and refresh the hitting slot
const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;
const size_t grid_size = ((len - 1) / keys_per_block) + 1;
get_kernel<key_type, ref_counter_type, atomic_ref_counter_type, slabset, set_hasher, slab_hasher,
mutex, empty_key, set_associativity, warp_size><<<grid_size, BLOCK_SIZE_, 0, stream>>>(
d_keys, len, d_values, embedding_vec_size_, d_missing_index, d_missing_keys, d_missing_len,
global_counter_, slot_counter_, capacity_in_set_, keys_, vals_, set_mutex_,
task_per_warp_tile);
// Check for GPU error before return
CUDA_CHECK(cudaGetLastError());
}
#else
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
void gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::Query(const key_type* d_keys, const size_t len, float* d_values,
uint64_t* d_missing_index, key_type* d_missing_keys,
size_t* d_missing_len, cudaStream_t stream,
const size_t task_per_warp_tile) {
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Check if it is a valid query
if (len == 0) {
// Set the d_missing_len to 0 before return
CUDA_CHECK(cudaMemsetAsync(d_missing_len, 0, sizeof(size_t), stream));
return;
}
// Update the global counter as user perform a new(most recent) read operation to the cache
// Resolve distance overflow issue as well.
update_kernel_overflow_ignore<ref_counter_type>
<<<1, 1, 0, stream>>>(global_counter_, d_missing_len);
// Read from the cache
// Touch and refresh the hitting slot
const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;
const size_t grid_size = ((len - 1) / keys_per_block) + 1;
get_kernel<key_type, ref_counter_type, slabset, set_hasher, slab_hasher, empty_key,
set_associativity, warp_size><<<grid_size, BLOCK_SIZE_, 0, stream>>>(
d_keys, len, d_values, embedding_vec_size_, d_missing_index, d_missing_keys, d_missing_len,
global_counter_, slot_counter_, capacity_in_set_, keys_, vals_, set_mutex_,
task_per_warp_tile);
// Check for GPU error before return
CUDA_CHECK(cudaGetLastError());
}
#endif
#ifdef LIBCUDACXX_VERSION
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
void gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::Replace(const key_type* d_keys, const size_t len,
const float* d_values, cudaStream_t stream,
const size_t task_per_warp_tile) {
// Check if it is a valid replacement
if (len == 0) {
return;
}
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Try to insert the <k,v> paris into the cache as long as there are unused slot
// Then replace the <k,v> pairs into the cache
const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;
const size_t grid_size = ((len - 1) / keys_per_block) + 1;
insert_replace_kernel<key_type, slabset, ref_counter_type, mutex, atomic_ref_counter_type,
set_hasher, slab_hasher, empty_key, set_associativity, warp_size>
<<<grid_size, BLOCK_SIZE_, 0, stream>>>(d_keys, d_values, embedding_vec_size_, len, keys_,
vals_, slot_counter_, set_mutex_, global_counter_,
capacity_in_set_, task_per_warp_tile);
// Check for GPU error before return
CUDA_CHECK(cudaGetLastError());
}
#else
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
void gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::Replace(const key_type* d_keys, const size_t len,
const float* d_values, cudaStream_t stream,
const size_t task_per_warp_tile) {
// Check if it is a valid replacement
if (len == 0) {
return;
}
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Try to insert the <k,v> paris into the cache as long as there are unused slot
// Then replace the <k,v> pairs into the cache
const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;
const size_t grid_size = ((len - 1) / keys_per_block) + 1;
insert_replace_kernel<key_type, slabset, ref_counter_type, set_hasher, slab_hasher, empty_key,
set_associativity, warp_size><<<grid_size, BLOCK_SIZE_, 0, stream>>>(
d_keys, d_values, embedding_vec_size_, len, keys_, vals_, slot_counter_, set_mutex_,
global_counter_, capacity_in_set_, task_per_warp_tile);
// Check for GPU error before return
CUDA_CHECK(cudaGetLastError());
}
#endif
#ifdef LIBCUDACXX_VERSION
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
void gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::Update(const key_type* d_keys, const size_t len, const float* d_values,
cudaStream_t stream, const size_t task_per_warp_tile) {
// Check if it is a valid update request
if (len == 0) {
return;
}
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Update the value of input keys that are existed in the cache
const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;
const size_t grid_size = ((len - 1) / keys_per_block) + 1;
update_kernel<key_type, slabset, set_hasher, slab_hasher, mutex, empty_key, set_associativity,
warp_size><<<grid_size, BLOCK_SIZE_, 0, stream>>>(
d_keys, len, d_values, embedding_vec_size_, capacity_in_set_, keys_, vals_, set_mutex_,
task_per_warp_tile);
// Check for GPU error before return
CUDA_CHECK(cudaGetLastError());
}
#else
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
void gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::Update(const key_type* d_keys, const size_t len, const float* d_values,
cudaStream_t stream, const size_t task_per_warp_tile) {
// Check if it is a valid update request
if (len == 0) {
return;
}
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Update the value of input keys that are existed in the cache
const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;
const size_t grid_size = ((len - 1) / keys_per_block) + 1;
update_kernel<key_type, slabset, set_hasher, slab_hasher, empty_key, set_associativity, warp_size>
<<<grid_size, BLOCK_SIZE_, 0, stream>>>(d_keys, len, d_values, embedding_vec_size_,
capacity_in_set_, keys_, vals_, set_mutex_,
task_per_warp_tile);
// Check for GPU error before return
CUDA_CHECK(cudaGetLastError());
}
#endif
#ifdef LIBCUDACXX_VERSION
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
void gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::Dump(key_type* d_keys, size_t* d_dump_counter,
const size_t start_set_index, const size_t end_set_index,
cudaStream_t stream) {
// Check if it is a valid dump request
if (start_set_index >= capacity_in_set_) {
printf("Error: Invalid value for start_set_index. Nothing dumped.\n");
return;
}
if (end_set_index <= start_set_index || end_set_index > capacity_in_set_) {
printf("Error: Invalid value for end_set_index. Nothing dumped.\n");
return;
}
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Set the global counter to 0 first
CUDA_CHECK(cudaMemsetAsync(d_dump_counter, 0, sizeof(size_t), stream));
// Dump keys from the cache
const size_t grid_size =
(((end_set_index - start_set_index) - 1) / (BLOCK_SIZE_ / warp_size)) + 1;
dump_kernel<key_type, slabset, mutex, empty_key, set_associativity, warp_size>
<<<grid_size, BLOCK_SIZE_, 0, stream>>>(d_keys, d_dump_counter, keys_, set_mutex_,
start_set_index, end_set_index);
// Check for GPU error before return
CUDA_CHECK(cudaGetLastError());
}
#else
template <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,
int warp_size, typename set_hasher, typename slab_hasher>
void gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,
slab_hasher>::Dump(key_type* d_keys, size_t* d_dump_counter,
const size_t start_set_index, const size_t end_set_index,
cudaStream_t stream) {
// Check if it is a valid dump request
if (start_set_index >= capacity_in_set_) {
printf("Error: Invalid value for start_set_index. Nothing dumped.\n");
return;
}
if (end_set_index <= start_set_index || end_set_index > capacity_in_set_) {
printf("Error: Invalid value for end_set_index. Nothing dumped.\n");
return;
}
// Device Restorer
nv::CudaDeviceRestorer dev_restorer;
// Check device
dev_restorer.check_device(dev_);
// Set the global counter to 0 first
CUDA_CHECK(cudaMemsetAsync(d_dump_counter, 0, sizeof(size_t), stream));
// Dump keys from the cache
const size_t grid_size =
(((end_set_index - start_set_index) - 1) / (BLOCK_SIZE_ / warp_size)) + 1;
dump_kernel<key_type, slabset, empty_key, set_associativity, warp_size>
<<<grid_size, BLOCK_SIZE_, 0, stream>>>(d_keys, d_dump_counter, keys_, set_mutex_,
start_set_index, end_set_index);
// Check for GPU error before return
CUDA_CHECK(cudaGetLastError());
}
#endif
template class gpu_cache<unsigned int, uint64_t, std::numeric_limits<unsigned int>::max(),
SET_ASSOCIATIVITY, SLAB_SIZE>;
template class gpu_cache<long long, uint64_t, std::numeric_limits<long long>::max(),
SET_ASSOCIATIVITY, SLAB_SIZE>;
} // namespace gpu_cache
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cooperative_groups.h>
#include <cuda.h>
#include <stdint.h>
#include <stdio.h>
#include <static_hash_table.hpp>
namespace gpu_cache {
template <typename T>
__device__ __forceinline__ T atomicCASHelper(T *address, T compare, T val) {
return atomicCAS(address, compare, val);
}
template <>
__device__ __forceinline__ long long atomicCASHelper(long long *address, long long compare,
long long val) {
return (long long)atomicCAS((unsigned long long *)address, (unsigned long long)compare,
(unsigned long long)val);
}
template <>
__device__ __forceinline__ int64_t atomicCASHelper(int64_t *address, int64_t compare, int64_t val) {
return (int64_t)atomicCAS((unsigned long long *)address, (unsigned long long)compare,
(unsigned long long)val);
}
template <unsigned int group_size, typename key_type, typename size_type, typename hasher,
typename CG>
__device__ size_type insert(key_type *table, size_type capacity, key_type key, const hasher &hash,
const CG &cg, const key_type empty_key, const size_type invalid_slot) {
// If insert successfully, return its position in the table,
// otherwise return invalid_slot.
const size_type num_groups = capacity / group_size;
#if (CUDA_VERSION < 11060)
unsigned long long num_threads_per_group = cg.size();
#else
unsigned long long num_threads_per_group = cg.num_threads();
#endif
const unsigned int num_tiles_per_group = group_size / num_threads_per_group;
// Assuming capacity is a power of 2
size_type slot = hash(key) & (capacity - 1);
slot = slot - (slot & (size_type)(group_size - 1)) + cg.thread_rank();
for (size_type step = 0; step < num_groups; ++step) {
for (unsigned int i = 0; i < num_tiles_per_group; ++i) {
key_type existed_key = table[slot];
// Check if key already exists
bool existed = cg.any(existed_key == key);
if (existed) {
return invalid_slot;
}
// Try to insert the target key into empty slot
while (true) {
int can_insert = cg.ballot(existed_key == empty_key);
if (!can_insert) {
break;
}
bool succeed = false;
int src_lane = __ffs(can_insert) - 1;
if (cg.thread_rank() == src_lane) {
key_type old = atomicCASHelper(table + slot, empty_key, key);
if (old == empty_key) {
// Insert key successfully
succeed = true;
} else if (old == key) {
// The target key was inserted by another thread
succeed = true;
slot = invalid_slot;
} else {
// The empty slot was occupied by another key,
// update the existed_key for next loop.
existed_key = old;
}
}
succeed = cg.shfl(succeed, src_lane);
if (succeed) {
slot = cg.shfl(slot, src_lane);
return slot;
}
}
slot += num_threads_per_group;
}
slot = (slot + group_size * step) & (capacity - 1);
}
return invalid_slot;
}
template <unsigned int tile_size, unsigned int group_size, typename key_type, typename size_type,
typename hasher>
__global__ void InsertKeyKernel(key_type *table_keys, size_type *table_indices, size_type capacity,
const key_type *keys, size_type num_keys, size_type offset,
hasher hash, const key_type empty_key,
const size_type invalid_slot) {
static_assert(tile_size <= group_size, "tile_size cannot be larger than group_size");
auto block = cooperative_groups::this_thread_block();
auto tile = cooperative_groups::tiled_partition<tile_size>(block);
int tile_idx = tile.meta_group_size() * block.group_index().x + tile.meta_group_rank();
int tile_cnt = tile.meta_group_size() * gridDim.x;
for (size_type i = tile_idx; i < num_keys; i += tile_cnt) {
key_type key = keys[i];
if (key == empty_key) {
if (tile.thread_rank() == 0 && table_keys[capacity] != empty_key) {
table_keys[capacity] = empty_key;
table_indices[capacity] = i + offset;
}
continue;
}
size_type slot =
insert<group_size>(table_keys, capacity, key, hash, tile, empty_key, invalid_slot);
if (tile.thread_rank() == 0 && slot != invalid_slot) {
table_indices[slot] = i + offset;
}
}
}
template <unsigned int group_size, typename key_type, typename size_type, typename hasher,
typename CG>
__device__ size_type lookup(key_type *table, size_type capacity, key_type key, const hasher &hash,
const CG &cg, const key_type empty_key, const size_type invalid_slot) {
// If lookup successfully, return the target key's position in the table,
// otherwise return invalid_slot.
const size_type num_groups = capacity / group_size;
#if (CUDA_VERSION < 11060)
unsigned long long num_threads_per_group = cg.size();
#else
unsigned long long num_threads_per_group = cg.num_threads();
#endif
const unsigned int num_tiles_per_group = group_size / num_threads_per_group;
// Assuming capacity is a power of 2
size_type slot = hash(key) & (capacity - 1);
slot = slot - (slot & (size_type)(group_size - 1)) + cg.thread_rank();
for (size_type step = 0; step < num_groups; ++step) {
for (unsigned int i = 0; i < num_tiles_per_group; ++i) {
key_type existed_key = table[slot];
// Check if key exists
int existed = cg.ballot(existed_key == key);
if (existed) {
int src_lane = __ffs(existed) - 1;
slot = cg.shfl(slot, src_lane);
return slot;
}
// The target key doesn't exist
bool contain_empty = cg.any(existed_key == empty_key);
if (contain_empty) {
return invalid_slot;
}
slot += num_threads_per_group;
}
slot = (slot + group_size * step) & (capacity - 1);
}
return invalid_slot;
}
template <int warp_size>
__forceinline__ __device__ void warp_tile_copy(const size_t lane_idx,
const size_t emb_vec_size_in_float,
volatile float *d_dst, const float *d_src) {
// 16 bytes align
if (emb_vec_size_in_float % 4 != 0 || (size_t)d_dst % 16 != 0 || (size_t)d_src % 16 != 0) {
#pragma unroll
for (size_t i = lane_idx; i < emb_vec_size_in_float; i += warp_size) {
d_dst[i] = d_src[i];
}
} else {
#pragma unroll
for (size_t i = lane_idx; i < emb_vec_size_in_float / 4; i += warp_size) {
*(float4 *)(d_dst + i * 4) = __ldg((const float4 *)(d_src + i * 4));
}
}
}
template <int warp_size>
__forceinline__ __device__ void warp_tile_copy(const size_t lane_idx,
const size_t emb_vec_size_in_float,
volatile float *d_dst, const float default_value) {
#pragma unroll
for (size_t i = lane_idx; i < emb_vec_size_in_float; i += warp_size) {
d_dst[i] = default_value;
}
}
template <unsigned int tile_size, unsigned int group_size, typename key_type, typename value_type,
typename size_type, typename hasher>
__global__ void LookupKernel(key_type *table_keys, size_type *table_indices, size_type capacity,
const key_type *keys, int num_keys, const value_type *values,
int value_dim, value_type *output, hasher hash,
const key_type empty_key, const value_type default_value,
const size_type invalid_slot) {
static_assert(tile_size <= group_size, "tile_size cannot be larger than group_size");
constexpr int WARP_SIZE = 32;
static_assert(WARP_SIZE % tile_size == 0, "tile_size must be divisible by warp_size");
auto grid = cooperative_groups::this_grid();
auto block = cooperative_groups::this_thread_block();
auto tile = cooperative_groups::tiled_partition<tile_size>(block);
auto warp_tile = cooperative_groups::tiled_partition<WARP_SIZE>(block);
int tile_idx = tile.meta_group_size() * block.group_index().x + tile.meta_group_rank();
int tile_cnt = tile.meta_group_size() * gridDim.x;
for (int it = 0; it < (num_keys - 1) / tile_cnt + 1; it++) {
size_type slot = invalid_slot;
int key_num = it * tile_cnt + tile_idx;
if (key_num < num_keys) {
key_type key = keys[key_num];
if (key == empty_key) {
if (tile.thread_rank() == 0 && table_keys[capacity] == key) {
slot = capacity;
}
} else {
slot = lookup<group_size>(table_keys, capacity, key, hash, tile, empty_key, invalid_slot);
}
}
for (int i = 0; i < WARP_SIZE / tile_size; i++) {
auto slot_to_read = warp_tile.shfl(slot, i * tile_size);
int idx_to_write = warp_tile.shfl(key_num, 0) + i;
if (idx_to_write >= num_keys) break;
if (slot_to_read == invalid_slot) {
warp_tile_copy<WARP_SIZE>(warp_tile.thread_rank(), value_dim,
output + (size_t)value_dim * idx_to_write, default_value);
continue;
}
auto index = table_indices[slot_to_read];
warp_tile_copy<WARP_SIZE>(warp_tile.thread_rank(), value_dim,
output + (size_t)value_dim * idx_to_write,
values + (size_t)value_dim * index);
}
}
}
template <typename key_type, typename value_type, unsigned int tile_size, unsigned int group_size,
typename hasher>
StaticHashTable<key_type, value_type, tile_size, group_size, hasher>::StaticHashTable(
size_type capacity, int value_dim, hasher hash)
: table_keys_(nullptr),
table_indices_(nullptr),
key_capacity_(capacity * 2),
table_values_(nullptr),
value_capacity_(capacity),
value_dim_(value_dim),
size_(0),
hash_(hash) {
// Check parameters
if (capacity <= 0) {
printf("Error: capacity must be larger than 0\n");
exit(EXIT_FAILURE);
}
if (value_dim <= 0) {
printf("Error: value_dim must be larger than 0\n");
exit(EXIT_FAILURE);
}
// Make key_capacity_ be a power of 2
size_t new_capacity = group_size;
while (new_capacity < key_capacity_) {
new_capacity *= 2;
}
key_capacity_ = new_capacity;
// Allocate device memory
size_t align_m = 16;
size_t num_keys = key_capacity_ + 1;
size_t num_values = (value_capacity_ * value_dim_ + align_m - 1) / align_m * align_m;
CUDA_CHECK(cudaMalloc(&table_keys_, sizeof(key_type) * num_keys));
CUDA_CHECK(cudaMalloc(&table_indices_, sizeof(size_type) * num_keys));
CUDA_CHECK(cudaMalloc(&table_values_, sizeof(value_type) * num_values));
// Initialize table_keys_
CUDA_CHECK(cudaMemset(table_keys_, 0xff, sizeof(key_type) * key_capacity_));
CUDA_CHECK(cudaMemset(table_keys_ + key_capacity_, 0, sizeof(key_type)));
}
template <typename key_type, typename value_type, unsigned int tile_size, unsigned int group_size,
typename hasher>
void StaticHashTable<key_type, value_type, tile_size, group_size, hasher>::insert(
const key_type *keys, const value_type *values, size_type num_keys, cudaStream_t stream) {
if (num_keys == 0) {
return;
}
if (num_keys <= 0 || (size() + num_keys) > capacity()) {
printf("Error: Invalid num_keys to insert\n");
exit(EXIT_FAILURE);
}
// Insert keys
constexpr int block = 256;
int grid = (num_keys - 1) / block + 1;
InsertKeyKernel<tile_size, group_size>
<<<grid, block, 0, stream>>>(table_keys_, table_indices_, key_capacity_, keys, num_keys,
size_, hash_, empty_key, invalid_slot);
// Copy values
CUDA_CHECK(cudaMemcpyAsync(table_values_ + size_ * value_dim_, values,
sizeof(value_type) * num_keys * value_dim_, cudaMemcpyDeviceToDevice,
stream));
size_ += num_keys;
}
template <typename key_type, typename value_type, unsigned int tile_size, unsigned int group_size,
typename hasher>
void StaticHashTable<key_type, value_type, tile_size, group_size, hasher>::clear(
cudaStream_t stream) {
CUDA_CHECK(cudaMemsetAsync(table_keys_, 0xff, sizeof(key_type) * key_capacity_, stream));
CUDA_CHECK(cudaMemsetAsync(table_keys_ + key_capacity_, 0, sizeof(key_type), stream));
size_ = 0;
}
template <typename key_type, typename value_type, unsigned int tile_size, unsigned int group_size,
typename hasher>
StaticHashTable<key_type, value_type, tile_size, group_size, hasher>::~StaticHashTable() {
CUDA_CHECK(cudaFree(table_keys_));
CUDA_CHECK(cudaFree(table_indices_));
CUDA_CHECK(cudaFree(table_values_));
}
template <typename key_type, typename value_type, unsigned int tile_size, unsigned int group_size,
typename hasher>
void StaticHashTable<key_type, value_type, tile_size, group_size, hasher>::lookup(
const key_type *keys, value_type *values, int num_keys, value_type default_value,
cudaStream_t stream) {
if (num_keys == 0) {
return;
}
constexpr int block = 256;
const int grid = (num_keys - 1) / block + 1;
// Lookup keys
LookupKernel<tile_size, group_size><<<grid, block, 0, stream>>>(
table_keys_, table_indices_, key_capacity_, keys, num_keys, table_values_, value_dim_, values,
hash_, empty_key, default_value, invalid_slot);
}
template class StaticHashTable<long long, float>;
template class StaticHashTable<uint32_t, float>;
} // namespace gpu_cache
\ No newline at end of file
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cooperative_groups.h>
#include <nv_util.h>
#include <iostream>
#include <static_hash_table.hpp>
#include <static_table.hpp>
namespace gpu_cache {
template <typename key_type>
static_table<key_type>::static_table(const size_t table_size, const size_t embedding_vec_size,
const float default_value)
: table_size_(table_size),
embedding_vec_size_(embedding_vec_size),
default_value_(default_value),
static_hash_table_(table_size, embedding_vec_size) {
if (embedding_vec_size_ == 0) {
printf("Error: Invalid value for embedding_vec_size.\n");
return;
}
}
template <typename key_type>
void static_table<key_type>::Query(const key_type* d_keys, const size_t len, float* d_values,
cudaStream_t stream) {
static_hash_table_.lookup(d_keys, d_values, len, default_value_, stream);
}
template <typename key_type>
void static_table<key_type>::Init(const key_type* d_keys, const size_t len, const float* d_values,
cudaStream_t stream) {
static_hash_table_.insert(d_keys, d_values, len, stream);
}
template <typename key_type>
void static_table<key_type>::Clear(cudaStream_t stream) {
static_hash_table_.clear(stream);
}
template class static_table<unsigned int>;
template class static_table<long long>;
} // namespace gpu_cache
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cooperative_groups.h>
#include <cuda_runtime_api.h>
#include <immintrin.h>
#include <atomic>
#include <iostream>
#include <limits>
#include <mutex>
#include <uvm_table.hpp>
namespace cg = cooperative_groups;
namespace {
constexpr int set_size = 4;
constexpr int block_size = 256;
template <typename key_type>
__host__ __device__ key_type hash(key_type key) {
return key;
}
template <typename key_type>
__global__ void hash_add_kernel(const key_type* new_keys, const int num_keys, key_type* keys,
const int num_sets, int* set_sizes, const int max_set_size,
key_type* missing_keys, int* num_missing_keys) {
__shared__ key_type s_missing_keys[block_size];
__shared__ int s_missing_count;
__shared__ size_t s_missing_idx;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
if (block.thread_rank() == 0) {
s_missing_count = 0;
}
block.sync();
size_t idx = grid.thread_rank();
if (idx < num_keys) {
auto key = new_keys[idx];
size_t idx_set = hash(key) % num_sets;
int prev_set_size = atomicAdd(&set_sizes[idx_set], 1);
if (prev_set_size < max_set_size) {
keys[idx_set * max_set_size + prev_set_size] = key;
} else {
int count = atomicAdd(&s_missing_count, 1);
s_missing_keys[count] = key;
}
}
block.sync();
if (block.thread_rank() == 0) {
s_missing_idx = atomicAdd(num_missing_keys, s_missing_count);
}
block.sync();
for (size_t i = block.thread_rank(); i < s_missing_count; i += block.num_threads()) {
missing_keys[s_missing_idx + i] = s_missing_keys[i];
}
}
template <typename key_type, typename index_type>
__global__ void hash_query_kernel(const key_type* query_keys, int* num_keys_ptr,
const key_type* keys, const size_t num_sets,
const int max_set_size, index_type* output_indices) {
constexpr int tile_size = set_size;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto tile = cg::tiled_partition<tile_size>(block);
int num_keys = *num_keys_ptr;
if (num_keys == 0) return;
#if (CUDA_VERSION < 11060)
size_t num_threads_per_grid = grid.size();
#else
size_t num_threads_per_grid = grid.num_threads();
#endif
size_t step = (num_keys - 1) / num_threads_per_grid + 1;
for (size_t i = 0; i < step; i++) {
size_t idx = i * num_threads_per_grid + grid.thread_rank();
key_type query_key = std::numeric_limits<key_type>::max();
if (idx < num_keys) {
query_key = query_keys[idx];
}
auto idx_set = hash(query_key) % num_sets;
for (int j = 0; j < tile_size; j++) {
auto current_idx_set = tile.shfl(idx_set, j);
auto current_query_key = tile.shfl(query_key, j);
if (current_query_key == std::numeric_limits<key_type>::max()) {
continue;
}
auto candidate_key = keys[current_idx_set * set_size + tile.thread_rank()];
int existed = tile.ballot(current_query_key == candidate_key);
auto current_idx = tile.shfl(idx, 0) + j;
if (existed) {
int src_lane = __ffs(existed) - 1;
size_t found_idx = current_idx_set * set_size + src_lane;
output_indices[current_idx] = num_sets * src_lane + current_idx_set;
} else {
output_indices[current_idx] = std::numeric_limits<index_type>::max();
}
}
}
}
template <typename key_type, typename index_type>
__global__ void hash_query_kernel(const key_type* query_keys, const int num_keys,
const key_type* keys, const size_t num_sets,
const int max_set_size, index_type* output_indices,
key_type* missing_keys, int* missing_positions,
int* missing_count) {
__shared__ key_type s_missing_keys[block_size];
__shared__ key_type s_missing_positions[block_size];
__shared__ int s_missing_count;
__shared__ int s_missing_idx;
constexpr int tile_size = set_size;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto tile = cg::tiled_partition<tile_size>(block);
if (block.thread_rank() == 0) {
s_missing_count = 0;
}
block.sync();
size_t idx = grid.thread_rank();
key_type query_key = std::numeric_limits<key_type>::max();
if (idx < num_keys) {
query_key = query_keys[idx];
}
auto idx_set = hash(query_key) % num_sets;
for (int j = 0; j < tile_size; j++) {
auto current_idx_set = tile.shfl(idx_set, j);
auto current_query_key = tile.shfl(query_key, j);
if (current_query_key == std::numeric_limits<key_type>::max()) {
continue;
}
auto candidate_key = keys[current_idx_set * set_size + tile.thread_rank()];
int existed = tile.ballot(current_query_key == candidate_key);
if (existed) {
int src_lane = __ffs(existed) - 1;
size_t found_idx = current_idx_set * set_size + src_lane;
output_indices[tile.shfl(idx, 0) + j] = num_sets * src_lane + current_idx_set;
} else {
auto current_idx = tile.shfl(idx, 0) + j;
output_indices[current_idx] = std::numeric_limits<index_type>::max();
if (tile.thread_rank() == 0) {
int s_count = atomicAdd(&s_missing_count, 1);
s_missing_keys[s_count] = current_query_key;
s_missing_positions[s_count] = current_idx;
}
}
}
if (missing_keys == nullptr) {
if (grid.thread_rank() == 0 && missing_count) {
*missing_count = 0;
}
return;
}
block.sync();
if (block.thread_rank() == 0) {
s_missing_idx = atomicAdd(missing_count, s_missing_count);
}
block.sync();
for (size_t i = block.thread_rank(); i < s_missing_count; i += block.num_threads()) {
missing_keys[s_missing_idx + i] = s_missing_keys[i];
missing_positions[s_missing_idx + i] = s_missing_positions[i];
}
}
template <int warp_size>
__forceinline__ __device__ void warp_tile_copy(const size_t lane_idx,
const size_t emb_vec_size_in_float,
volatile float* d_dst, const float* d_src) {
// 16 bytes align
if (emb_vec_size_in_float % 4 != 0 || (size_t)d_dst % 16 != 0 || (size_t)d_src % 16 != 0) {
#pragma unroll
for (size_t i = lane_idx; i < emb_vec_size_in_float; i += warp_size) {
d_dst[i] = d_src[i];
}
} else {
#pragma unroll
for (size_t i = lane_idx; i < emb_vec_size_in_float / 4; i += warp_size) {
*(float4*)(d_dst + i * 4) = __ldg((const float4*)(d_src + i * 4));
}
}
}
template <typename index_type, typename vec_type>
__global__ void read_vectors_kernel(const index_type* query_indices, const int num_keys,
const vec_type* vectors, const int vec_size,
vec_type* output_vectors) {
constexpr int warp_size = 32;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto tile = cg::tiled_partition<warp_size>(block);
#if (CUDA_VERSION < 11060)
auto num_threads_per_grid = grid.size();
#else
auto num_threads_per_grid = grid.num_threads();
#endif
for (int step = 0; step < (num_keys - 1) / num_threads_per_grid + 1; step++) {
int key_num = step * num_threads_per_grid + grid.thread_rank();
index_type idx = std::numeric_limits<index_type>::max();
if (key_num < num_keys) {
idx = query_indices[key_num];
}
#pragma unroll 4
for (size_t j = 0; j < warp_size; j++) {
index_type current_idx = tile.shfl(idx, j);
index_type idx_write = tile.shfl(key_num, 0) + j;
if (current_idx == std::numeric_limits<index_type>::max()) continue;
warp_tile_copy<warp_size>(tile.thread_rank(), vec_size, output_vectors + idx_write * vec_size,
vectors + current_idx * vec_size);
}
}
}
template <typename index_type, typename vec_type>
__global__ void distribute_vectors_kernel(const index_type* postions, const size_t num_keys,
const vec_type* vectors, const int vec_size,
vec_type* output_vectors) {
constexpr int warp_size = 32;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto tile = cg::tiled_partition<warp_size>(block);
#if (CUDA_VERSION < 11060)
auto num_threads_per_grid = grid.size();
#else
auto num_threads_per_grid = grid.num_threads();
#endif
for (size_t step = 0; step < (num_keys - 1) / num_threads_per_grid + 1; step++) {
size_t key_num = step * num_threads_per_grid + grid.thread_rank();
index_type idx = std::numeric_limits<index_type>::max();
if (key_num < num_keys) {
idx = postions[key_num];
}
#pragma unroll 4
for (size_t j = 0; j < warp_size; j++) {
size_t idx_write = tile.shfl(idx, j);
size_t idx_read = tile.shfl(key_num, 0) + j;
if (idx_write == std::numeric_limits<index_type>::max()) continue;
warp_tile_copy<warp_size>(tile.thread_rank(), vec_size,
output_vectors + (size_t)idx_write * vec_size,
vectors + (size_t)idx_read * vec_size);
}
}
}
} // namespace
namespace gpu_cache {
template <typename key_type, typename index_type, typename vec_type>
UvmTable<key_type, index_type, vec_type>::UvmTable(const size_t device_table_capacity,
const size_t host_table_capacity,
const int max_batch_size, const int vec_size,
const vec_type default_value)
: max_batch_size_(std::max(100000, max_batch_size)),
vec_size_(vec_size),
num_set_((device_table_capacity - 1) / set_size + 1),
num_host_set_((host_table_capacity - 1) / set_size + 1),
table_capacity_(num_set_ * set_size),
default_vector_(vec_size, default_value),
device_table_(device_table_capacity, set_size, max_batch_size_),
host_table_(host_table_capacity * 1.3, set_size, max_batch_size_) {
CUDA_CHECK(cudaMalloc(&d_keys_buffer_, sizeof(key_type) * max_batch_size_));
CUDA_CHECK(cudaMalloc(&d_vectors_buffer_, sizeof(vec_type) * max_batch_size_ * vec_size_));
CUDA_CHECK(cudaMalloc(&d_vectors_, sizeof(vec_type) * device_table_.capacity * vec_size_));
CUDA_CHECK(cudaMalloc(&d_output_indices_, sizeof(index_type) * max_batch_size_));
CUDA_CHECK(cudaMalloc(&d_output_host_indices_, sizeof(index_type) * max_batch_size_));
CUDA_CHECK(cudaMallocHost(&h_output_host_indices_, sizeof(index_type) * max_batch_size_));
CUDA_CHECK(cudaMalloc(&d_missing_keys_, sizeof(key_type) * max_batch_size_));
CUDA_CHECK(cudaMalloc(&d_missing_positions_, sizeof(int) * max_batch_size_));
CUDA_CHECK(cudaMalloc(&d_missing_count_, sizeof(int)));
CUDA_CHECK(cudaMemset(d_missing_count_, 0, sizeof(int)));
CUDA_CHECK(cudaStreamCreate(&query_stream_));
for (int i = 0; i < num_buffers_; i++) {
int batch_size_per_buffer = ceil(1.0 * max_batch_size_ / num_buffers_);
CUDA_CHECK(
cudaMallocHost(&h_cpy_buffers_[i], sizeof(vec_type) * batch_size_per_buffer * vec_size));
CUDA_CHECK(cudaMalloc(&d_cpy_buffers_[i], sizeof(vec_type) * batch_size_per_buffer * vec_size));
CUDA_CHECK(cudaStreamCreate(&cpy_streams_[i]));
CUDA_CHECK(cudaEventCreate(&cpy_events_[i]));
}
CUDA_CHECK(cudaMallocHost(&h_missing_keys_, sizeof(key_type) * max_batch_size_));
CUDA_CHECK(cudaEventCreate(&query_event_));
h_vectors_.resize(host_table_.capacity * vec_size_);
}
template <typename key_type, typename index_type, typename vec_type>
void UvmTable<key_type, index_type, vec_type>::add(const key_type* h_keys,
const vec_type* h_vectors,
const size_t num_keys) {
std::vector<key_type> h_missing_keys;
size_t num_batches = (num_keys - 1) / max_batch_size_ + 1;
for (size_t i = 0; i < num_batches; i++) {
size_t this_batch_size =
i != num_batches - 1 ? max_batch_size_ : num_keys - i * max_batch_size_;
CUDA_CHECK(cudaMemcpy(d_keys_buffer_, h_keys + i * max_batch_size_,
sizeof(*d_keys_buffer_) * this_batch_size, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemset(d_missing_count_, 0, sizeof(*d_missing_count_)));
device_table_.add(d_keys_buffer_, this_batch_size, d_missing_keys_, d_missing_count_, 0);
CUDA_CHECK(cudaDeviceSynchronize());
int num_missing_keys;
CUDA_CHECK(cudaMemcpy(&num_missing_keys, d_missing_count_, sizeof(num_missing_keys),
cudaMemcpyDeviceToHost));
size_t prev_size = h_missing_keys.size();
h_missing_keys.resize(prev_size + num_missing_keys);
CUDA_CHECK(cudaMemcpy(h_missing_keys.data() + prev_size, d_missing_keys_,
sizeof(*d_missing_keys_) * num_missing_keys, cudaMemcpyDeviceToHost));
}
std::vector<key_type> h_final_missing_keys;
num_batches = h_missing_keys.size() ? (h_missing_keys.size() - 1) / max_batch_size_ + 1 : 0;
for (size_t i = 0; i < num_batches; i++) {
size_t this_batch_size =
i != num_batches - 1 ? max_batch_size_ : h_missing_keys.size() - i * max_batch_size_;
CUDA_CHECK(cudaMemcpy(d_keys_buffer_, h_missing_keys.data() + i * max_batch_size_,
sizeof(*d_keys_buffer_) * this_batch_size, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemset(d_missing_count_, 0, sizeof(*d_missing_count_)));
host_table_.add(d_keys_buffer_, this_batch_size, d_missing_keys_, d_missing_count_, 0);
CUDA_CHECK(cudaDeviceSynchronize());
int num_missing_keys;
CUDA_CHECK(cudaMemcpy(&num_missing_keys, d_missing_count_, sizeof(num_missing_keys),
cudaMemcpyDeviceToHost));
size_t prev_size = h_final_missing_keys.size();
h_final_missing_keys.resize(prev_size + num_missing_keys);
CUDA_CHECK(cudaMemcpy(h_final_missing_keys.data() + prev_size, d_missing_keys_,
sizeof(*d_missing_keys_) * num_missing_keys, cudaMemcpyDeviceToHost));
}
std::vector<key_type> h_keys_buffer(max_batch_size_);
std::vector<index_type> h_indices_buffer(max_batch_size_);
std::vector<int> h_positions_buffer(max_batch_size_);
num_batches = (num_keys - 1) / max_batch_size_ + 1;
size_t num_hit_keys = 0;
for (size_t i = 0; i < num_batches; i++) {
size_t this_batch_size =
i != num_batches - 1 ? max_batch_size_ : num_keys - i * max_batch_size_;
CUDA_CHECK(cudaMemcpy(d_keys_buffer_, h_keys + i * max_batch_size_,
sizeof(*d_keys_buffer_) * this_batch_size, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemset(d_missing_count_, 0, sizeof(*d_missing_count_)));
device_table_.query(d_keys_buffer_, this_batch_size, d_output_indices_, d_missing_keys_,
d_missing_positions_, d_missing_count_, 0);
CUDA_CHECK(cudaStreamSynchronize(0));
CUDA_CHECK(cudaMemcpy(d_vectors_buffer_, h_vectors + i * max_batch_size_ * vec_size_,
sizeof(*d_vectors_) * this_batch_size * vec_size_,
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaStreamSynchronize(0));
if (num_hit_keys < device_table_.capacity) {
distribute_vectors_kernel<<<(this_batch_size - 1) / block_size + 1, block_size, 0, 0>>>(
d_output_indices_, this_batch_size, d_vectors_buffer_, vec_size_, d_vectors_);
CUDA_CHECK(cudaStreamSynchronize(0));
}
int num_missing_keys;
CUDA_CHECK(cudaMemcpy(&num_missing_keys, d_missing_count_, sizeof(num_missing_keys),
cudaMemcpyDeviceToHost));
num_hit_keys += this_batch_size - num_missing_keys;
host_table_.query(d_missing_keys_, num_missing_keys, d_output_indices_, nullptr, nullptr,
nullptr, 0);
CUDA_CHECK(cudaMemcpy(h_keys_buffer.data(), d_missing_keys_,
sizeof(*d_missing_keys_) * num_missing_keys, cudaMemcpyDeviceToHost))
CUDA_CHECK(cudaMemcpy(h_indices_buffer.data(), d_output_indices_,
sizeof(*d_output_indices_) * num_missing_keys, cudaMemcpyDeviceToHost))
CUDA_CHECK(cudaMemcpy(h_positions_buffer.data(), d_missing_positions_,
sizeof(*d_missing_positions_) * num_missing_keys, cudaMemcpyDeviceToHost))
for (int j = 0; j < num_missing_keys; j++) {
if (h_indices_buffer[j] != std::numeric_limits<index_type>::max()) {
memcpy(h_vectors_.data() + h_indices_buffer[j] * vec_size_,
h_vectors + (i * max_batch_size_ + h_positions_buffer[j]) * vec_size_,
sizeof(*h_vectors) * vec_size_);
} else {
size_t prev_idx = h_vectors_.size() / vec_size_;
h_final_missing_items_.emplace(h_keys_buffer[j], prev_idx);
h_vectors_.resize(h_vectors_.size() + vec_size_);
memcpy(h_vectors_.data() + prev_idx * vec_size_,
h_vectors + (i * max_batch_size_ + h_positions_buffer[j]) * vec_size_,
sizeof(*h_vectors) * vec_size_);
}
}
}
CUDA_CHECK(cudaMemset(d_missing_count_, 0, sizeof(*d_missing_count_)));
}
template <typename key_type, typename index_type, typename vec_type>
void UvmTable<key_type, index_type, vec_type>::query(const key_type* d_keys, const int num_keys,
vec_type* d_vectors, cudaStream_t stream) {
if (!num_keys) return;
CUDA_CHECK(cudaEventRecord(query_event_, stream));
CUDA_CHECK(cudaStreamWaitEvent(query_stream_, query_event_));
static_assert(num_buffers_ >= 2);
device_table_.query(d_keys, num_keys, d_output_indices_, d_missing_keys_, d_missing_positions_,
d_missing_count_, query_stream_);
CUDA_CHECK(cudaEventRecord(query_event_, query_stream_));
CUDA_CHECK(cudaStreamWaitEvent(cpy_streams_[0], query_event_));
int num_missing_keys;
CUDA_CHECK(cudaMemcpyAsync(&num_missing_keys, d_missing_count_, sizeof(*d_missing_count_),
cudaMemcpyDeviceToHost, cpy_streams_[0]));
host_table_.query(d_missing_keys_, d_missing_count_, d_output_host_indices_, query_stream_);
CUDA_CHECK(cudaStreamSynchronize(cpy_streams_[0]));
CUDA_CHECK(cudaMemsetAsync(d_missing_count_, 0, sizeof(*d_missing_count_), query_stream_));
CUDA_CHECK(cudaMemcpyAsync(h_output_host_indices_, d_output_host_indices_,
sizeof(index_type) * num_missing_keys, cudaMemcpyDeviceToHost,
query_stream_));
CUDA_CHECK(cudaMemcpyAsync(h_missing_keys_, d_missing_keys_, sizeof(key_type) * num_missing_keys,
cudaMemcpyDeviceToHost, cpy_streams_[0]));
read_vectors_kernel<<<(num_keys - 1) / block_size + 1, block_size, 0, cpy_streams_[1]>>>(
d_output_indices_, num_keys, d_vectors_, vec_size_, d_vectors);
CUDA_CHECK(cudaStreamSynchronize(query_stream_));
CUDA_CHECK(cudaStreamSynchronize(cpy_streams_[0]));
int num_keys_per_buffer = ceil(1.0 * num_missing_keys / num_buffers_);
for (int buffer_num = 0; buffer_num < num_buffers_; buffer_num++) {
int num_keys_this_buffer = buffer_num != num_buffers_ - 1
? num_keys_per_buffer
: num_missing_keys - num_keys_per_buffer * buffer_num;
if (!num_keys_this_buffer) break;
#pragma omp parallel for num_threads(8)
for (size_t i = 0; i < static_cast<size_t>(num_keys_this_buffer); i++) {
size_t idx_key = buffer_num * num_keys_per_buffer + i;
index_type index = h_output_host_indices_[idx_key];
if (index == std::numeric_limits<index_type>::max()) {
key_type key = h_missing_keys_[idx_key];
auto iterator = h_final_missing_items_.find(key);
if (iterator != h_final_missing_items_.end()) {
index = iterator->second;
}
}
if (index != std::numeric_limits<index_type>::max()) {
memcpy(h_cpy_buffers_[buffer_num] + i * vec_size_, h_vectors_.data() + index * vec_size_,
sizeof(vec_type) * vec_size_);
} else {
memcpy(h_cpy_buffers_[buffer_num] + i * vec_size_, default_vector_.data(),
sizeof(vec_type) * vec_size_);
}
}
CUDA_CHECK(cudaMemcpyAsync(d_cpy_buffers_[buffer_num], h_cpy_buffers_[buffer_num],
sizeof(vec_type) * num_keys_this_buffer * vec_size_,
cudaMemcpyHostToDevice, cpy_streams_[buffer_num]));
distribute_vectors_kernel<<<(num_keys_this_buffer - 1) / block_size + 1, block_size, 0,
cpy_streams_[buffer_num]>>>(
d_missing_positions_ + buffer_num * num_keys_per_buffer, num_keys_this_buffer,
d_cpy_buffers_[buffer_num], vec_size_, d_vectors);
}
for (int i = 0; i < num_buffers_; i++) {
CUDA_CHECK(cudaEventRecord(cpy_events_[i], cpy_streams_[i]));
CUDA_CHECK(cudaStreamWaitEvent(stream, cpy_events_[i]));
}
}
template <typename key_type, typename index_type, typename vec_type>
void UvmTable<key_type, index_type, vec_type>::clear(cudaStream_t stream) {
device_table_.clear(stream);
host_table_.clear(stream);
}
template <typename key_type, typename index_type, typename vec_type>
UvmTable<key_type, index_type, vec_type>::~UvmTable() {
CUDA_CHECK(cudaFree(d_keys_buffer_));
CUDA_CHECK(cudaFree(d_vectors_buffer_));
CUDA_CHECK(cudaFree(d_vectors_));
CUDA_CHECK(cudaFree(d_output_indices_));
CUDA_CHECK(cudaFree(d_output_host_indices_));
CUDA_CHECK(cudaFreeHost(h_output_host_indices_));
CUDA_CHECK(cudaFree(d_missing_keys_));
CUDA_CHECK(cudaFree(d_missing_positions_));
CUDA_CHECK(cudaFree(d_missing_count_));
CUDA_CHECK(cudaFreeHost(h_missing_keys_));
CUDA_CHECK(cudaStreamDestroy(query_stream_));
CUDA_CHECK(cudaEventDestroy(query_event_));
for (int i = 0; i < num_buffers_; i++) {
CUDA_CHECK(cudaFreeHost(h_cpy_buffers_[i]));
CUDA_CHECK(cudaFree(d_cpy_buffers_[i]));
CUDA_CHECK(cudaStreamDestroy(cpy_streams_[i]));
CUDA_CHECK(cudaEventDestroy(cpy_events_[i]));
}
}
template <typename key_type, typename index_type>
HashBlock<key_type, index_type>::HashBlock(size_t expected_capacity, int set_size, int batch_size)
: max_set_size_(set_size), batch_size_(batch_size) {
if (expected_capacity) {
num_sets = (expected_capacity - 1) / set_size + 1;
} else {
num_sets = 10000;
}
capacity = num_sets * set_size;
CUDA_CHECK(cudaMalloc(&keys, sizeof(*keys) * capacity));
CUDA_CHECK(cudaMalloc(&set_sizes_, sizeof(*set_sizes_) * num_sets));
CUDA_CHECK(cudaMemset(set_sizes_, 0, sizeof(*set_sizes_) * num_sets));
}
template <typename key_type, typename index_type>
HashBlock<key_type, index_type>::~HashBlock() {
CUDA_CHECK(cudaFree(keys));
CUDA_CHECK(cudaFree(set_sizes_));
}
template <typename key_type, typename index_type>
void HashBlock<key_type, index_type>::query(const key_type* query_keys, const size_t num_keys,
index_type* output_indices, key_type* missing_keys,
int* missing_positions, int* num_missing_keys,
cudaStream_t stream) {
if (num_keys == 0) {
return;
}
size_t num_batches = (num_keys - 1) / batch_size_ + 1;
for (size_t i = 0; i < num_batches; i++) {
size_t this_batch_size = i != num_batches - 1 ? batch_size_ : num_keys - i * batch_size_;
hash_query_kernel<<<(this_batch_size - 1) / block_size + 1, block_size, 0, stream>>>(
query_keys, this_batch_size, keys, num_sets, max_set_size_, output_indices, missing_keys,
missing_positions, num_missing_keys);
}
}
template <typename key_type, typename index_type>
void HashBlock<key_type, index_type>::query(const key_type* query_keys, int* num_keys,
index_type* output_indices, cudaStream_t stream) {
hash_query_kernel<<<128, 64, 0, stream>>>(query_keys, num_keys, keys, num_sets, max_set_size_,
output_indices);
}
template <typename key_type, typename index_type>
void HashBlock<key_type, index_type>::add(const key_type* new_keys, const size_t num_keys,
key_type* missing_keys, int* num_missing_keys,
cudaStream_t stream) {
if (num_keys == 0) {
return;
}
size_t num_batches = (num_keys - 1) / batch_size_ + 1;
for (size_t i = 0; i < num_batches; i++) {
size_t this_batch_size = i != num_batches - 1 ? batch_size_ : num_keys - i * batch_size_;
hash_add_kernel<<<(this_batch_size - 1) / block_size + 1, block_size, 0, stream>>>(
new_keys + i * this_batch_size, this_batch_size, keys, num_sets, set_sizes_, max_set_size_,
missing_keys, num_missing_keys);
}
}
template <typename key_type, typename index_type>
void HashBlock<key_type, index_type>::clear(cudaStream_t stream) {
CUDA_CHECK(cudaMemsetAsync(set_sizes_, 0, sizeof(*set_sizes_) * num_sets, stream));
}
template class HashBlock<int, size_t>;
template class HashBlock<int64_t, size_t>;
template class HashBlock<size_t, size_t>;
template class HashBlock<unsigned int, size_t>;
template class HashBlock<long long, size_t>;
template class UvmTable<int, size_t>;
template class UvmTable<int64_t, size_t>;
template class UvmTable<size_t, size_t>;
template class UvmTable<unsigned int, size_t>;
template class UvmTable<long long, size_t>;
} // namespace gpu_cache
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment