Unverified Commit f8933bbf authored by Charlene Yang, committed by GitHub
Browse files

[Common] Optimize KV cache related kernels (#1914)



* optimize kv_cache reindex and copy kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* avoid reindexing from python side
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename variable from previous commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5350f277
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
namespace transformer_engine { namespace transformer_engine {
namespace kv_cache { namespace kv_cache {
constexpr int block_size = 1024;
template <typename dtype> template <typename dtype>
__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices, __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
...@@ -22,21 +24,29 @@ __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *bat ...@@ -22,21 +24,29 @@ __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *bat
actual_b = i + 1; actual_b = i + 1;
} }
} }
bool flag = (batch_indices[0] != 0);
for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) { for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; if (flag || ((batch_indices[batch_idx] - batch_indices[0]) != batch_idx)) {
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; int num_tokens = (cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]) -
for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) { (cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]);
int num_elts_k = h_kv * d_k; int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v; int num_elts_v = h_kv * d_v;
int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; int num_elts = max(num_elts_k, num_elts_v);
int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k; for (int token_idx = blockIdx.x; token_idx < num_tokens; token_idx += gridDim.x) {
int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v; int src_offset = batch_indices[batch_idx] * max_seq_len + token_idx;
int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v; int des_offset = batch_idx * max_seq_len + token_idx;
for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { dtype *k_cache_src_offset = k_cache + src_offset * num_elts_k;
*(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i); dtype *k_cache_des_offset = k_cache + des_offset * num_elts_k;
dtype *v_cache_src_offset = v_cache + src_offset * num_elts_v;
dtype *v_cache_des_offset = v_cache + des_offset * num_elts_v;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
if (i < num_elts_k) {
*(k_cache_des_offset + i) = *(k_cache_src_offset + i);
}
if (i < num_elts_v) {
*(v_cache_des_offset + i) = *(v_cache_src_offset + i);
}
} }
for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) {
*(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i);
} }
} }
} }
...@@ -55,19 +65,26 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac ...@@ -55,19 +65,26 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int new_token_offset = batch_idx * max_ctx_len;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) { int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int hd = h_kv * max(d_k, d_v);
for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; dtype *new_token_id_k = new_k + (batch_idx * max_ctx_len + i) * num_elts_k;
for (int j = 0; j < h_kv * d_k; j++) { dtype *new_token_id_v = new_v + (batch_idx * max_ctx_len + i) * num_elts_v;
*(k_cache + token_idx * h_kv * d_k + j) = dtype *token_id_k =
*(new_k + (new_token_offset + i) * h_kv * d_k + j); k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
dtype *token_id_v =
v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
for (int j = threadIdx.x; j < hd; j += blockDim.x) {
if (j < num_elts_k) {
*(token_id_k + j) = *(new_token_id_k + j);
}
if (j < num_elts_v) {
*(token_id_v + j) = *(new_token_id_v + j);
} }
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (new_token_offset + i) * h_kv * d_v + j);
} }
} }
} }
...@@ -76,14 +93,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac ...@@ -76,14 +93,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) { int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int hd = h_kv * max(d_k, d_v);
for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; dtype *new_token_id_k = new_k + (i * b + batch_idx) * num_elts_k;
for (int j = 0; j < h_kv * d_k; j++) { dtype *new_token_id_v = new_v + (i * b + batch_idx) * num_elts_v;
*(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j); dtype *token_id_k =
k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
dtype *token_id_v =
v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
for (int j = threadIdx.x; j < hd; j += blockDim.x) {
if (j < num_elts_k) {
*(token_id_k + j) = *(new_token_id_k + j);
}
if (j < num_elts_v) {
*(token_id_v + j) = *(new_token_id_v + j);
} }
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j);
} }
} }
} }
...@@ -92,16 +119,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac ...@@ -92,16 +119,24 @@ __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cac
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) { int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int hd = h_kv * max(d_k, d_v);
for (int i = blockIdx.y; i < new_len; i += gridDim.y) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; dtype *new_token_id_k = new_k + (cu_new_lens[batch_idx] + i) * num_elts_k;
for (int j = 0; j < h_kv * d_k; j++) { dtype *new_token_id_v = new_v + (cu_new_lens[batch_idx] + i) * num_elts_v;
*(k_cache + token_idx * h_kv * d_k + j) = dtype *token_id_k =
*(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j); k_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_k;
dtype *token_id_v =
v_cache + (page_idx * page_size + (cached_len - new_len + i) % page_size) * num_elts_v;
for (int j = threadIdx.x; j < hd; j += blockDim.x) {
if (j < num_elts_k) {
*(token_id_k + j) = *(new_token_id_k + j);
}
if (j < num_elts_v) {
*(token_id_v + j) = *(new_token_id_v + j);
} }
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j);
} }
} }
} }
...@@ -116,14 +151,15 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso ...@@ -116,14 +151,15 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
bool is_non_paged, cudaStream_t stream) { bool is_non_paged, cudaStream_t stream) {
if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) { if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
if (is_non_paged) { if (is_non_paged) {
reindex_kv_cache_kernel<<<16, 256, 0, stream>>>( reindex_kv_cache_kernel<<<max_seq_len, block_size, 0, stream>>>(
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(k_cache.data.dptr),
reinterpret_cast<dtype *>(v_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr), reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr), reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len); reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
} }
copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>( dim3 grid_size(b, max_ctx_len);
copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(
reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr), reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr), reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr), reinterpret_cast<int *>(page_table.data.dptr),
......
...@@ -420,6 +420,8 @@ class NonPagedKVCacheManager(KVCacheManager): ...@@ -420,6 +420,8 @@ class NonPagedKVCacheManager(KVCacheManager):
dtype=torch.int32, dtype=torch.int32,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
) )
# whether reindexing is needed, i.e. when batch seq_ids have changed
self.need_reindex = True
def allocate_memory(self, layer_number): def allocate_memory(self, layer_number):
"""Allocate memory for the cache""" """Allocate memory for the cache"""
...@@ -451,6 +453,7 @@ class NonPagedKVCacheManager(KVCacheManager): ...@@ -451,6 +453,7 @@ class NonPagedKVCacheManager(KVCacheManager):
# step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that # step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that
# they are contiguous and match the indexing in q # they are contiguous and match the indexing in q
prev_batch_size = len(self.sequences) prev_batch_size = len(self.sequences)
prev_seq_ids = set(self.sequences.keys())
unfinished_seqs = self.sequences.keys() & step_dict.keys() unfinished_seqs = self.sequences.keys() & step_dict.keys()
finished_seqs = self.sequences.keys() - unfinished_seqs finished_seqs = self.sequences.keys() - unfinished_seqs
unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
...@@ -478,6 +481,9 @@ class NonPagedKVCacheManager(KVCacheManager): ...@@ -478,6 +481,9 @@ class NonPagedKVCacheManager(KVCacheManager):
for i in new_seqs: for i in new_seqs:
self.sequences[i] = step_dict[i] self.sequences[i] = step_dict[i]
# Whether reindexing is needed
self.need_reindex = set(self.sequences.keys()) != prev_seq_ids
return self.sequences return self.sequences
def step( def step(
...@@ -538,7 +544,7 @@ class NonPagedKVCacheManager(KVCacheManager): ...@@ -538,7 +544,7 @@ class NonPagedKVCacheManager(KVCacheManager):
ctx_len, ctx_len,
self.max_seqlen, self.max_seqlen,
1, 1,
True, self.need_reindex,
) )
k_cache = k_cache[:batch_size] k_cache = k_cache[:batch_size]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment