Unverified Commit 233bbb8c authored by UnicornChan, committed by GitHub

Merge pull request #57 from UnicornChan/develop-0.1.3

[feature] release 0.1.3
parents 67f8b370 4d1d561d
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : Jianwei Dong
Date : 2024-08-28 10:32:05
Version : 1.0.0
LastEditors : Jianwei Dong
LastEditTime : 2024-08-28 10:32:05
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
import torch
layer_num = 10
kv_head_num = 8
q_head_num = 32
head_dim = 128
block_len = 128
anchor_num = 1
anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
layer_step: int = 1
token_step: int = 1
layer_offset: int = 0
max_thread_num: int = 64
max_batch_size: int = 1
max_block_num: int = 1024
CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)
warm_up_iter = 1000
test_iter = 10000
def bench_linear(cache_seqlen: int):
with torch.inference_mode(mode=True):
cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu")
seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu")
config = cpuinfer_ext.kvcache.KVCacheConfig(
layer_num,
kv_head_num,
q_head_num,
head_dim,
block_len,
anchor_num,
anchor_type,
kv_type,
retrieval_type,
layer_step,
token_step,
layer_offset,
max_block_num,
max_batch_size,
max_thread_num,
)
local_kvcache = cpuinfer_ext.kvcache.KVCache(config)
block_table = (
torch.arange(max_block_num, dtype=torch.int32, device="cpu")
.contiguous()
.view(1, -1)
)
for layer_idx in range(layer_num):
k_cache = torch.randn(
(1, cache_seqlen, kv_head_num, head_dim),
dtype=torch.float16,
device="cpu",
).contiguous()
v_cache = torch.randn(
(1, cache_seqlen, kv_head_num, head_dim),
dtype=torch.float16,
device="cpu",
).contiguous()
CPUInfer.submit(
local_kvcache.update_kvcache_fp16(
k_cache.data_ptr(),
v_cache.data_ptr(),
layer_idx,
block_table.data_ptr(),
1,
max_block_num,
seqlens_zero.data_ptr(),
cache_seqlen,
)
)
CPUInfer.sync()
input = torch.randn(
(1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu"
).contiguous()
output = torch.empty(
(1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu"
).contiguous()
# attn_lse: (bsz, q_len, q_head_num)
attn_lse = torch.empty(
(1, 1, q_head_num), dtype=torch.float32, device="cpu"
).contiguous()
input = input / 100
# warm up
for i in range(warm_up_iter):
CPUInfer.submit(
local_kvcache.attn(
input.data_ptr(),
output.data_ptr(),
attn_lse.data_ptr(),
i % layer_num,
0,
1,
1,
max_block_num,
block_table.data_ptr(),
cache_seqlens.data_ptr(),
-1,
-1,
-1,
)
)
CPUInfer.sync()
# test
start = time.perf_counter()
for i in range(test_iter):
CPUInfer.submit(
local_kvcache.attn(
input.data_ptr(),
output.data_ptr(),
attn_lse.data_ptr(),
i % layer_num,
0,
1,
1,
max_block_num,
block_table.data_ptr(),
cache_seqlens.data_ptr(),
-1,
-1,
-1,
)
)
CPUInfer.sync()
end = time.perf_counter()
total_time = end - start
print("cache sequence length: ", cache_seqlen)
print("Time(s): ", total_time)
print("Iteration: ", test_iter)
print("Time(us) per iteration: ", total_time / test_iter * 1000000)
print(
"Bandwidth: ",
cache_seqlen
* kv_head_num
* head_dim
* 2
* 2
* test_iter
/ total_time
/ 1000
/ 1000
/ 1000,
"GB/s",
)
print("")
bench_linear(1024)
bench_linear(4096)
bench_linear(16384)
bench_linear(32768)
bench_linear(65536)
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : Jianwei Dong
Date : 2024-08-28 10:32:05
Version : 1.0.0
LastEditors : Jianwei Dong
LastEditTime : 2024-08-28 10:32:05
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
import torch
layer_num = 10
kv_head_num = 8
q_head_num = 32
head_dim = 128
block_len = 128
anchor_num = 1
warm_up_iter = 1000
test_iter = 10000
def bench_linear(cache_seqlen: int, device):
with torch.inference_mode(mode=True):
kvcaches = []
for layer_idx in range(layer_num):
k_cache = torch.randn(
(1, 32, cache_seqlen, head_dim),
dtype=torch.float16,
device=device,
).contiguous()
v_cache = torch.randn(
(1, 32, cache_seqlen, head_dim),
dtype=torch.float16,
device=device,
).contiguous()
kvcaches.append((k_cache, v_cache))
input = torch.randn(
(1, q_head_num, 1, head_dim), dtype=torch.float16, device=device
).contiguous()
input = input / 100
# warm up
for i in range(warm_up_iter):
k_cache = kvcaches[i % layer_num][0]
v_cache = kvcaches[i % layer_num][1]
torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)
# test
if device == "cuda":
    torch.cuda.synchronize()  # drain warm-up kernels before timing
start = time.perf_counter()
for i in range(test_iter):
k_cache = kvcaches[i % layer_num][0]
v_cache = kvcaches[i % layer_num][1]
torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)
if device == "cuda":
    torch.cuda.synchronize()  # wait for queued kernels so the timing reflects execution, not launch
end = time.perf_counter()
total_time = end - start
print("cache sequence length: ", cache_seqlen)
print("Time(s): ", total_time)
print("Iteration: ", test_iter)
print("Time(us) per iteration: ", total_time / test_iter * 1000000)
print(
"Bandwidth: ",
cache_seqlen
* q_head_num
* head_dim
* 2
* 2
* test_iter
/ total_time
/ 1000
/ 1000
/ 1000,
"GB/s",
)
print("")
bench_linear(1024, "cpu")
bench_linear(4096, "cpu")
bench_linear(1024, "cuda")
bench_linear(4096, "cuda")
bench_linear(16384, "cuda")
bench_linear(32768, "cuda")
bench_linear(65536, "cuda")
@@ -3,93 +3,125 @@
* @Author : chenht2022
* @Date : 2024-07-22 02:03:05
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:33:34
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "backend.h"
Backend::Backend(int thread_num) {
thread_num_ = thread_num;
thread_state_.resize(thread_num);
for (int i = 0; i < thread_num; i++) {
thread_local int Backend::thread_local_id = -1;
Backend::Backend(int max_thread_num) {
max_thread_num_ = max_thread_num;
thread_state_.resize(max_thread_num_);
for (int i = 0; i < max_thread_num_; i++) {
thread_state_[i].curr = std::make_unique<std::atomic<int>>();
thread_state_[i].status = std::make_unique<std::atomic<ThreadStatus>>(ThreadStatus::WAITING);
thread_state_[i].status =
std::make_unique<std::atomic<ThreadStatus>>(ThreadStatus::WAITING);
}
workers_.resize(thread_num);
for (int i = 1; i < thread_num; i++) {
workers_.resize(max_thread_num_);
for (int i = 1; i < max_thread_num_; i++) {
workers_[i] = std::thread(&Backend::worker_thread, this, i);
}
}
Backend::~Backend() {
for (int i = 0; i < thread_num_; i++) {
thread_state_[i].status->store(ThreadStatus::EXIT, std::memory_order_release);
for (int i = 0; i < max_thread_num_; i++) {
thread_state_[i].status->store(ThreadStatus::EXIT,
std::memory_order_release);
}
for (int i = 1; i < thread_num_; i++) {
for (int i = 1; i < max_thread_num_; i++) {
if (workers_[i].joinable()) {
workers_[i].join();
}
}
}
int Backend::get_thread_num() {
return thread_num_;
}
int Backend::get_thread_num() { return max_thread_num_; }
void Backend::do_work_stealing_job(int task_num, std::function<void(int)> func) {
func_ = func;
void Backend::do_work_stealing_job(int task_num,
std::function<void(int)> init_func,
std::function<void(int)> compute_func,
std::function<void(int)> finalize_func) {
init_func_ = init_func;
compute_func_ = compute_func;
finalize_func_ = finalize_func;
thread_num_ = std::min(max_thread_num_, task_num);
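// Contiguous split: the first `remain` threads take base + 1 tasks each, the rest take base.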
int base = task_num / thread_num_;
int remain = task_num % thread_num_;
thread_state_[0].end = base + (0 < remain);
// Set thread_local_id for the main thread
thread_local_id = 0;
for (int i = 1; i < thread_num_; i++) {
thread_state_[i].curr->store(thread_state_[i - 1].end, std::memory_order_relaxed);
thread_state_[i].curr->store(thread_state_[i - 1].end,
std::memory_order_relaxed);
thread_state_[i].end = thread_state_[i - 1].end + base + (i < remain);
thread_state_[i].status->store(ThreadStatus::WORKING, std::memory_order_release);
thread_state_[i].status->store(ThreadStatus::WORKING,
std::memory_order_release);
}
thread_state_[0].curr->store(0, std::memory_order_relaxed);
thread_state_[0].status->store(ThreadStatus::WORKING, std::memory_order_release);
thread_state_[0].status->store(ThreadStatus::WORKING,
std::memory_order_release);
process_tasks(0);
for (int i = 1; i < thread_num_; i++) {
while (thread_state_[i].status->load(std::memory_order_acquire) == ThreadStatus::WORKING) {
while (thread_state_[i].status->load(std::memory_order_acquire) ==
ThreadStatus::WORKING) {
}
}
}
void Backend::process_tasks(int thread_id) {
if (init_func_ != nullptr) {
init_func_(thread_id);
}
while (true) {
int task_id = thread_state_[thread_id].curr->fetch_add(1, std::memory_order_acq_rel);
int task_id = thread_state_[thread_id].curr->fetch_add(
1, std::memory_order_acq_rel);
if (task_id >= thread_state_[thread_id].end) {
break;
}
func_(task_id);
compute_func_(task_id);
}
for (int t_offset = 1; t_offset < thread_num_; t_offset++) {
int t_i = (thread_id + t_offset) % thread_num_;
if (thread_state_[t_i].status->load(std::memory_order_acquire) != ThreadStatus::WORKING) {
if (thread_state_[t_i].status->load(std::memory_order_acquire) !=
ThreadStatus::WORKING) {
continue;
}
while (true) {
int task_id = thread_state_[t_i].curr->fetch_add(1, std::memory_order_acq_rel);
int task_id = thread_state_[t_i].curr->fetch_add(
1, std::memory_order_acq_rel);
if (task_id >= thread_state_[t_i].end) {
break;
}
func_(task_id);
compute_func_(task_id);
}
}
thread_state_[thread_id].status->store(ThreadStatus::WAITING, std::memory_order_release);
if (finalize_func_ != nullptr) {
finalize_func_(thread_id);
}
thread_state_[thread_id].status->store(ThreadStatus::WAITING,
std::memory_order_release);
}
void Backend::worker_thread(int thread_id) {
auto start = std::chrono::steady_clock::now();
thread_local_id = thread_id; // set this worker's thread-local id
while (true) {
ThreadStatus status = thread_state_[thread_id].status->load(std::memory_order_acquire);
ThreadStatus status =
thread_state_[thread_id].status->load(std::memory_order_acquire);
if (status == ThreadStatus::WORKING) {
process_tasks(thread_id);
start = std::chrono::steady_clock::now();
} else if (status == ThreadStatus::WAITING) {
auto now = std::chrono::steady_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
auto duration =
std::chrono::duration_cast<std::chrono::milliseconds>(now -
start)
.count();
if (duration > 50) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
......
@@ -3,7 +3,7 @@
* @Author : chenht2022
* @Date : 2024-07-22 02:03:05
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:33:38
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
@@ -31,20 +31,25 @@ struct ThreadState {
};
class Backend {
public:
public:
Backend(int);
~Backend();
int get_thread_num();
void do_work_stealing_job(int, std::function<void(int)>);
void do_work_stealing_job(int, std::function<void(int)>,
std::function<void(int)>,
std::function<void(int)>);
static thread_local int thread_local_id;
private:
private:
int thread_num_;
std::vector<ThreadState> thread_state_; // [thread_num]
std::function<void(int)> func_;
int max_thread_num_;
std::vector<ThreadState> thread_state_; // [thread_num]
std::function<void(int)> init_func_;
std::function<void(int)> compute_func_;
std::function<void(int)> finalize_func_;
std::vector<std::thread> workers_;
void process_tasks(int);
void worker_thread(int);
};
#endif
\ No newline at end of file
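For orientation, a minimal sketch (not part of this diff) of how a caller drives the new three-callback work-stealing API declared above: init_func and finalize_func receive the thread id and may be passed as nullptr, as process_tasks checks, while compute_func receives the task id.
#include "backend.h"
#include <vector>

int main() {
    Backend backend(8);  // up to 8 worker threads
    std::vector<float> out(1000);
    backend.do_work_stealing_job(
        /*task_num=*/1000,
        /*init_func=*/nullptr,      // optional per-thread setup
        /*compute_func=*/[&](int task_id) { out[task_id] = 2.0f * task_id; },
        /*finalize_func=*/nullptr); // optional per-thread teardown
    return 0;
}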
@@ -54,4 +54,4 @@ void TaskQueue::processTasks() {
}
mutex.unlock();
}
}
}
\ No newline at end of file
@@ -4,7 +4,7 @@
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenxl
* @LastEditTime : 2024-08-12 12:28:25
* @LastEditTime : 2024-08-08 04:23:51
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_TASKQUEUE_H
......
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : Jianwei Dong
Date : 2024-08-28 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-08-28 10:32:05
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
from flash_attn import flash_attn_with_kvcache
import torch
layer_num = 10
kv_head_num = 8
q_head_num = 32
head_dim = 128
block_len = 128
anchor_num = 1
cache_seqlen = 8192
cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu")
seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu")
anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
layer_step: int = 1
token_step: int = 1
layer_offset: int = 0
max_thread_num: int = 2
max_batch_size: int = 1
max_block_num: int = 512
CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)
validation_iter = 100
with torch.inference_mode(mode=True):
config = cpuinfer_ext.kvcache.KVCacheConfig(
layer_num,
kv_head_num,
q_head_num,
head_dim,
block_len,
anchor_num,
anchor_type,
kv_type,
retrieval_type,
layer_step,
token_step,
layer_offset,
max_block_num,
max_batch_size,
max_thread_num,
)
local_kvcache = cpuinfer_ext.kvcache.KVCache(config)
kvcaches = []
block_table = (
torch.arange(max_block_num, dtype=torch.int32, device="cpu")
.contiguous()
.view(1, -1)
)
for layer_idx in range(layer_num):
k_cache = torch.randn(
(1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device="cpu"
).contiguous()
v_cache = torch.randn(
(1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device="cpu"
).contiguous()
CPUInfer.submit(
local_kvcache.update_kvcache_fp16(
k_cache.data_ptr(),
v_cache.data_ptr(),
layer_idx,
block_table.data_ptr(),
1,
max_block_num,
seqlens_zero.data_ptr(),
cache_seqlen,
)
)
CPUInfer.sync()
kvcaches.append((k_cache.to("cuda"), v_cache.to("cuda")))
# validation
for i in range(validation_iter):
k_cache = kvcaches[i % layer_num][0]
v_cache = kvcaches[i % layer_num][1]
input = torch.randn(
(1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu"
).contiguous()
output = torch.empty(
(1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu"
).contiguous()
# attn_lse: (bsz, q_len, q_head_num)
attn_lse = torch.empty(
(1, 1, q_head_num), dtype=torch.float32, device="cpu"
).contiguous()
input = input / 100
CPUInfer.submit(
local_kvcache.attn(
input.data_ptr(),
output.data_ptr(),
attn_lse.data_ptr(),
i % layer_num,
0,
1,
1,
max_block_num,
block_table.data_ptr(),
cache_seqlens.data_ptr(),
-1,
-1,
-1,
)
)
CPUInfer.sync()
# print("cpuinfer output", output)
t_output = flash_attn_with_kvcache(
q=input.to("cuda"),
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens.to("cuda"),
)
# print("torch output", t_output)
diff = torch.mean(torch.abs(output.to("cuda") - t_output)) / torch.mean(
torch.abs(t_output)
)
print("diff = ", diff)
assert diff < 0.001
/**
* @Description :
* @Author : chenht2022
* @Author : chenht2022, Jianwei Dong
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-07 10:39:37
* @LastEditors : Jianwei Dong
* @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
// Python bindings
#include <cstdint>
#include <iostream>
#include <memory>
#include "cpu_backend/cpuinfer.h"
#include "device_launch_parameters.h"
#include "llamafile/flags.h"
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
#include "operators/llamafile/mlp.h"
#include "operators/llamafile/moe.h"
@@ -21,119 +19,541 @@
#include "pybind11/operators.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include <cstdint>
#include <iostream>
#include <memory>
namespace py = pybind11;
using namespace pybind11::literals;
// Binding functions for the KVCache class
class KVCacheBindings {
public:
class AttnBindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
const ggml_fp16_t *q_in;
ggml_fp16_t *output;
float *attn_lse;
int layer_idx;
int generate_token_idx;
int q_len;
int batch_size;
int max_block_num;
int *block_table;
int *cache_seqlens;
int pick_block_num;
int init_block_num;
int local_block_num;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(
&KVCache::attn, args_->kv_cache, args_->q_in, args_->output,
args_->attn_lse, args_->layer_idx, args_->generate_token_idx,
args_->q_len, args_->batch_size, args_->max_block_num,
args_->block_table, args_->cache_seqlens, args_->pick_block_num,
args_->init_block_num, args_->local_block_num);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t output,
intptr_t attn_lse, int layer_idx,
int generate_token_idx, int q_len, int batch_size,
int max_block_num, intptr_t block_table,
intptr_t cache_seqlens, int pick_block_num,
int init_block_num, int local_block_num) {
Args *args = new Args{nullptr,
&kv_cache,
(const ggml_fp16_t *)q_in,
(ggml_fp16_t *)output,
(float *)attn_lse,
layer_idx,
generate_token_idx,
q_len,
batch_size,
max_block_num,
(int *)block_table,
(int *)cache_seqlens,
pick_block_num,
init_block_num,
local_block_num};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class GetAllKVCacheOneLayerBindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
int layer_id;
ggml_fp16_t *k_in;
ggml_fp16_t *v_in;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&KVCache::get_all_kvcache_one_layer,
args_->kv_cache, args_->layer_id,
args_->k_in, args_->v_in);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,
int layer_id) {
Args *args = new Args{nullptr, &kv_cache, layer_id,
(ggml_fp16_t *)k_in, (ggml_fp16_t *)v_in};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class GetAndUpdateKVCacheFp16Bindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
ggml_fp16_t *k_in;
ggml_fp16_t *v_in;
int layer_id;
int *block_table;
int batch_size;
int max_block_num;
int *cache_seqlens;
int q_len;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&KVCache::get_and_update_kvcache_fp16,
args_->kv_cache, args_->k_in, args_->v_in,
args_->layer_id, args_->block_table,
args_->batch_size, args_->max_block_num,
args_->cache_seqlens, args_->q_len);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,
int layer_id, intptr_t block_table, int batch_size,
int max_block_num, intptr_t cache_seqlens,
int q_len) {
Args *args = new Args{nullptr,
&kv_cache,
(ggml_fp16_t *)k_in,
(ggml_fp16_t *)v_in,
layer_id,
(int *)block_table,
batch_size,
max_block_num,
(int *)cache_seqlens,
q_len};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class GetKVCacheFp16Bindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
ggml_fp16_t *k_in;
ggml_fp16_t *v_in;
int layer_id;
int *block_table;
int batch_size;
int max_block_num;
int *cache_seqlens;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(
&KVCache::get_kvcache_fp16, args_->kv_cache, args_->k_in,
args_->v_in, args_->layer_id, args_->block_table,
args_->batch_size, args_->max_block_num, args_->cache_seqlens);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,
int layer_id, intptr_t block_table, int batch_size,
int max_block_num, intptr_t cache_seqlens) {
Args *args = new Args{nullptr,
&kv_cache,
(ggml_fp16_t *)k_in,
(ggml_fp16_t *)v_in,
layer_id,
(int *)block_table,
batch_size,
max_block_num,
(int *)cache_seqlens};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class UpdateKVCacheFp16Bindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
ggml_fp16_t *k_in;
ggml_fp16_t *v_in;
int layer_id;
int *block_table;
int batch_size;
int max_block_num;
int *cache_seqlens;
int q_len;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&KVCache::update_kvcache_fp16,
args_->kv_cache, args_->k_in, args_->v_in,
args_->layer_id, args_->block_table,
args_->batch_size, args_->max_block_num,
args_->cache_seqlens, args_->q_len);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,
int layer_id, intptr_t block_table, int batch_size,
int max_block_num, intptr_t cache_seqlens,
int q_len) {
Args *args = new Args{nullptr,
&kv_cache,
(ggml_fp16_t *)k_in,
(ggml_fp16_t *)v_in,
layer_id,
(int *)block_table,
batch_size,
max_block_num,
(int *)cache_seqlens,
q_len};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class UpdateImportanceBindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
const ggml_fp16_t *importance;
int layer_id;
int *block_table;
int batch_size;
int max_block_num;
int *offset;
int width;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(
&KVCache::update_importance, args_->kv_cache, args_->importance,
args_->layer_id, args_->block_table, args_->batch_size,
args_->max_block_num, args_->offset, args_->width);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t importance, int layer_id,
intptr_t block_table, int batch_size,
int max_block_num, intptr_t offset, int width) {
Args *args = new Args{nullptr,
&kv_cache,
(const ggml_fp16_t *)importance,
layer_id,
(int *)block_table,
batch_size,
max_block_num,
(int *)offset,
width};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class AttnWithKVCacheBindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
const ggml_fp16_t *q_in;
const ggml_fp16_t *k_in;
const ggml_fp16_t *v_in;
ggml_fp16_t *output;
float *attn_lse;
int layer_idx;
int generate_token_idx;
int q_len;
int batch_size;
int max_block_num;
int *block_table;
int *cache_seqlens;
int topk;
int local;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(
&KVCache::attn_with_kvcache, args_->kv_cache, args_->q_in,
args_->k_in, args_->v_in, args_->output, args_->attn_lse,
args_->layer_idx, args_->generate_token_idx, args_->q_len,
args_->batch_size, args_->max_block_num, args_->block_table,
args_->cache_seqlens, args_->topk, args_->local);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t k_in,
intptr_t v_in, intptr_t output, intptr_t attn_lse,
int layer_idx, int generate_token_idx, int q_len,
int batch_size, int max_block_num,
intptr_t block_table, intptr_t cache_seqlens,
int topk, int local) {
Args *args = new Args{nullptr,
&kv_cache,
(const ggml_fp16_t *)q_in,
(const ggml_fp16_t *)k_in,
(const ggml_fp16_t *)v_in,
(ggml_fp16_t *)output,
(float *)attn_lse,
layer_idx,
generate_token_idx,
q_len,
batch_size,
max_block_num,
(int *)block_table,
(int *)cache_seqlens,
topk,
local};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ClearImportanceAllLayersBindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
int *block_table;
int *cache_seqlens;
int batch_size;
int max_block_num;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&KVCache::clear_importance_all_layers,
args_->kv_cache, args_->block_table,
args_->cache_seqlens, args_->batch_size,
args_->max_block_num);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,
intptr_t cache_seqlens, int batch_size,
int max_block_num) {
Args *args = new Args{nullptr,
&kv_cache,
(int *)block_table,
(int *)cache_seqlens,
batch_size,
max_block_num};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class CalcAnchorAllLayersBindinds {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
int *block_table;
int *cache_seqlens;
int batch_size;
int max_block_num;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&KVCache::calc_anchor_all_layers,
args_->kv_cache, args_->block_table,
args_->cache_seqlens, args_->batch_size,
args_->max_block_num);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,
intptr_t cache_seqlens, int batch_size,
int max_block_num) {
Args *args = new Args{nullptr,
&kv_cache,
(int *)block_table,
(int *)cache_seqlens,
batch_size,
max_block_num};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class LoadKVCacheBindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
std::string tensor_file_path;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&KVCache::load_kvcache, args_->kv_cache,
args_->tensor_file_path);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, std::string tensor_file_path) {
Args *args =
new Args{nullptr, &kv_cache, (std::string)tensor_file_path};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class DumpKVCacheBindings {
public:
struct Args {
CPUInfer *cpuinfer;
KVCache *kv_cache;
int *block_table;
int cache_total_len;
std::string tensor_file_path;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&KVCache::dump_kvcache, args_->kv_cache,
args_->block_table, args_->cache_total_len,
args_->tensor_file_path);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,
int cache_total_len, std::string tensor_file_path) {
Args *args =
new Args{nullptr, &kv_cache, (int *)block_table,
cache_total_len, (std::string)tensor_file_path};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
class LinearBindings {
public:
public:
class WarmUpBindinds {
public:
public:
struct Args {
CPUInfer* cpuinfer;
Linear* linear;
CPUInfer *cpuinfer;
Linear *linear;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear) {
Args* args = new Args{nullptr, &linear};
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(Linear &linear) {
Args *args = new Args{nullptr, &linear};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
public:
public:
struct Args {
CPUInfer* cpuinfer;
Linear* linear;
CPUInfer *cpuinfer;
Linear *linear;
int qlen;
const void* input;
void* output;
const void *input;
void *output;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&Linear::forward, args_->linear, args_->qlen, args_->input, args_->output);
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&Linear::forward, args_->linear,
args_->qlen, args_->input, args_->output);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
Args* args = new Args{nullptr, &linear, qlen, (const void*)input, (void*)output};
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(Linear &linear, int qlen, intptr_t input,
intptr_t output) {
Args *args = new Args{nullptr, &linear, qlen, (const void *)input,
(void *)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
class MLPBindings {
public:
public:
class WarmUpBindinds {
public:
public:
struct Args {
CPUInfer* cpuinfer;
MLP* mlp;
CPUInfer *cpuinfer;
MLP *mlp;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp) {
Args* args = new Args{nullptr, &mlp};
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP &mlp) {
Args *args = new Args{nullptr, &mlp};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
public:
public:
struct Args {
CPUInfer* cpuinfer;
MLP* mlp;
CPUInfer *cpuinfer;
MLP *mlp;
int qlen;
const void* input;
void* output;
const void *input;
void *output;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen, args_->input, args_->output);
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen,
args_->input, args_->output);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
Args* args = new Args{nullptr, &mlp, qlen, (const void*)input, (void*)output};
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(MLP &mlp, int qlen, intptr_t input,
intptr_t output) {
Args *args = new Args{nullptr, &mlp, qlen, (const void *)input,
(void *)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
class MOEBindings {
public:
public:
class WarmUpBindinds {
public:
public:
struct Args {
CPUInfer* cpuinfer;
MOE* moe;
CPUInfer *cpuinfer;
MOE *moe;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe) {
Args* args = new Args{nullptr, &moe};
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE &moe) {
Args *args = new Args{nullptr, &moe};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
public:
public:
struct Args {
CPUInfer* cpuinfer;
MOE* moe;
CPUInfer *cpuinfer;
MOE *moe;
int qlen;
int k;
const uint64_t* expert_ids;
const float* weights;
const void* input;
void* output;
const uint64_t *expert_ids;
const float *weights;
const void *input;
void *output;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output);
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(
&MOE::forward, args_->moe, args_->qlen, args_->k,
args_->expert_ids, args_->weights, args_->input, args_->output);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
Args* args = new Args{nullptr, &moe, qlen, k, (const uint64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output};
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(MOE &moe, int qlen, int k, intptr_t expert_ids,
intptr_t weights, intptr_t input, intptr_t output) {
Args *args = new Args{nullptr,
&moe,
qlen,
k,
(const uint64_t *)expert_ids,
(const float *)weights,
(const void *)input,
(void *)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
@@ -149,8 +569,12 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
auto linear_module = m.def_submodule("linear");
py::class_<LinearConfig>(linear_module, "LinearConfig")
.def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t proj, int proj_type, int hidden_type) {
return LinearConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)proj, (ggml_type)proj_type, (ggml_type)hidden_type);
.def(py::init([](int hidden_size, int intermediate_size, int stride,
int group_max_len, intptr_t proj, int proj_type,
int hidden_type) {
return LinearConfig(hidden_size, intermediate_size, stride,
group_max_len, (void *)proj,
(ggml_type)proj_type, (ggml_type)hidden_type);
}));
py::class_<Linear>(linear_module, "Linear")
.def(py::init<LinearConfig>())
@@ -159,8 +583,15 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
auto mlp_module = m.def_submodule("mlp");
py::class_<MLPConfig>(mlp_module, "MLPConfig")
.def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) {
return MLPConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)gate_proj, (void*)up_proj, (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type);
.def(py::init([](int hidden_size, int intermediate_size, int stride,
int group_max_len, intptr_t gate_proj,
intptr_t up_proj, intptr_t down_proj, int gate_type,
int up_type, int down_type, int hidden_type) {
return MLPConfig(hidden_size, intermediate_size, stride,
group_max_len, (void *)gate_proj, (void *)up_proj,
(void *)down_proj, (ggml_type)gate_type,
(ggml_type)up_type, (ggml_type)down_type,
(ggml_type)hidden_type);
}));
py::class_<MLP>(mlp_module, "MLP")
.def(py::init<MLPConfig>())
@@ -169,11 +600,84 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
auto moe_module = m.def_submodule("moe");
py::class_<MOEConfig>(moe_module, "MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) {
return MOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, stride, group_min_len, group_max_len, (void*)gate_proj, (void*)up_proj, (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type);
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
int intermediate_size, int stride, int group_min_len,
int group_max_len, intptr_t gate_proj,
intptr_t up_proj, intptr_t down_proj, int gate_type,
int up_type, int down_type, int hidden_type) {
return MOEConfig(expert_num, routed_expert_num, hidden_size,
intermediate_size, stride, group_min_len,
group_max_len, (void *)gate_proj, (void *)up_proj,
(void *)down_proj, (ggml_type)gate_type,
(ggml_type)up_type, (ggml_type)down_type,
(ggml_type)hidden_type);
}));
py::class_<MOE>(moe_module, "MOE")
.def(py::init<MOEConfig>())
.def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
auto kvcache_module = m.def_submodule("kvcache");
py::enum_<AnchorType>(kvcache_module, "AnchorType")
.value("FIXED", AnchorType::FIXED_ANCHOR)
.value("DYNAMIC", AnchorType::DYNAMIC)
.value("QUEST", AnchorType::QUEST)
.value("BLOCK_MAX", AnchorType::BLOCK_MAX)
.value("BLOCK_MEAN", AnchorType::BLOCK_MEAN);
py::enum_<ggml_type>(kvcache_module, "ggml_type")
.value("FP16", ggml_type::GGML_TYPE_F16)
.value("FP32", ggml_type::GGML_TYPE_F32)
.value("Q4_0", ggml_type::GGML_TYPE_Q4_0)
.value("Q8_0", ggml_type::GGML_TYPE_Q8_0);
py::enum_<RetrievalType>(kvcache_module, "RetrievalType")
.value("LAYER", RetrievalType::LAYER)
.value("KVHEAD", RetrievalType::KVHEAD)
.value("QHEAD", RetrievalType::QHEAD);
py::class_<KVCacheConfig>(kvcache_module, "KVCacheConfig")
.def(py::init<int, int, int, int, int, int, AnchorType, ggml_type,
RetrievalType, int, int, int, int, int, int>())
.def_readwrite("layer_num", &KVCacheConfig::layer_num)
.def_readwrite("kv_head_num", &KVCacheConfig::kv_head_num)
.def_readwrite("q_head_num", &KVCacheConfig::q_head_num)
.def_readwrite("head_dim", &KVCacheConfig::head_dim)
.def_readwrite("block_len", &KVCacheConfig::block_len)
.def_readwrite("anchor_num", &KVCacheConfig::anchor_num)
.def_readwrite("anchor_type", &KVCacheConfig::anchor_type)
.def_readwrite("kv_type", &KVCacheConfig::kv_type)
.def_readwrite("retrieval_type", &KVCacheConfig::retrieval_type)
.def_readwrite("layer_step", &KVCacheConfig::layer_step)
.def_readwrite("token_step", &KVCacheConfig::token_step)
.def_readwrite("layer_offset", &KVCacheConfig::layer_offset)
.def_readwrite("max_block_num", &KVCacheConfig::max_block_num)
.def_readwrite("max_batch_size", &KVCacheConfig::max_batch_size)
.def_readwrite("max_thread_num", &KVCacheConfig::max_thread_num);
py::class_<KVCache>(kvcache_module, "KVCache")
.def(py::init<KVCacheConfig>())
.def("get_cache_total_len", &KVCache::get_cache_total_len)
.def("update_cache_total_len",
[](KVCache &kvcache, int cache_total_len) {
kvcache.update_cache_total_len(cache_total_len);
})
.def("attn", &KVCacheBindings::AttnBindings::cpuinfer_interface)
.def(
"get_all_kvcache_one_layer",
&KVCacheBindings::GetAllKVCacheOneLayerBindings::cpuinfer_interface)
.def("get_and_update_kvcache_fp16",
&KVCacheBindings::GetAndUpdateKVCacheFp16Bindings::
cpuinfer_interface)
.def("get_kvcache_fp16",
&KVCacheBindings::GetKVCacheFp16Bindings::cpuinfer_interface)
.def("update_kvcache_fp16",
&KVCacheBindings::UpdateKVCacheFp16Bindings::cpuinfer_interface)
.def("update_importance",
&KVCacheBindings::UpdateImportanceBindings::cpuinfer_interface)
.def("attn_with_kvcache",
&KVCacheBindings::AttnWithKVCacheBindings::cpuinfer_interface)
.def("clear_importance_all_layers",
&KVCacheBindings::ClearImportanceAllLayersBindings::
cpuinfer_interface)
.def("calc_anchor_all_layers",
&KVCacheBindings::CalcAnchorAllLayersBindinds::cpuinfer_interface);
}
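Each *Bindings class above follows the same pattern: a plain Args struct, a static inner(void *) trampoline that forwards the call through CPUInfer::enqueue, and a cpuinfer_interface that heap-allocates Args and returns the (function pointer, args pointer) pair that CPUInfer.submit consumes on the Python side. A stripped-down skeleton for a hypothetical new operator, for orientation only (MyOp, MyOp::run, and some_param are placeholders, not part of this diff):
class MyOpBindings {
  public:
    struct Args {
        CPUInfer *cpuinfer;  // left null here; expected to be filled in before inner() runs
        MyOp *op;
        int some_param;
    };
    static void inner(void *args) {
        Args *args_ = (Args *)args;
        args_->cpuinfer->enqueue(&MyOp::run, args_->op, args_->some_param);
    }
    static std::pair<intptr_t, intptr_t> cpuinfer_interface(MyOp &op, int some_param) {
        Args *args = new Args{nullptr, &op, some_param};
        return std::make_pair((intptr_t)&inner, (intptr_t)args);
    }
};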
/**
* @Description :
* @Author : Jianwei Dong
* @Date : 2024-08-26 22:47:06
* @Version : 1.0.0
* @LastEditors : Jianwei Dong
* @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_KVCACHE_H
#define CPUINFER_OPERATOR_KVCACHE_H
#include <algorithm>
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <random>
#include <stdexcept>
#include <thread>
#include <vector>
#include "../../cpu_backend/backend.h"
#include "llama.cpp/ggml-common.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#define CHUNK_SIZE 32
/**
* @brief Converts a ggml_type enum value to its corresponding string
* representation.
*
* This function provides a human-readable string representation for a given
* ggml_type enum value. The string can be used for logging, debugging, or
* displaying information in a user interface.
*
* @param type The ggml_type enum value to convert.
* @return A string representation of the enum value.
*/
std::string ggml_type_to_string(ggml_type type);
/**
* @enum AnchorType
* @brief Defines the types of anchors used in attention mechanisms.
*
* This enum specifies different types of anchors that can be used in attention
* mechanisms, such as fixed anchors, dynamic anchors, or special anchors like
* QUEST, BLOCK_MEAN, or BLOCK_MAX.
*/
enum AnchorType {
FIXED_ANCHOR, /**< A fixed anchor that does not change. */
DYNAMIC, /**< A dynamic anchor that can change over time. */
QUEST, /**< A special anchor type used for QUEST (Query and Embedding Space
Transformation). */
BLOCK_MEAN, /**< An anchor based on the mean of a block of data. */
BLOCK_MAX /**< An anchor based on the maximum value within a block of data.
*/
};
/**
* @brief Converts an AnchorType enum value to its corresponding string
* representation.
*
* This function provides a human-readable string representation for a given
* AnchorType enum value. The string can be used for logging, debugging, or
* displaying information in a user interface.
*
* @param anchor_type The AnchorType enum value to convert.
* @return A string representation of the enum value.
*/
std::string AnchorTypeToString(AnchorType anchor_type);
/**
* @enum RetrievalType
* @brief Defines the types of retrieval strategies in attention mechanisms.
*
* This enum specifies different retrieval strategies that can be used in
* attention mechanisms, such as layer-level retrieval, key-value head-level
* retrieval, or query head-level retrieval.
*/
enum RetrievalType {
LAYER, /**< Retrieval at the layer level. */
KVHEAD, /**< Retrieval at the key-value head level. */
QHEAD /**< Retrieval at the query head level. */
};
/**
* @brief Converts a RetrievalType enum value to its corresponding string
* representation.
*
* This function provides a human-readable string representation for a given
* RetrievalType enum value. The string can be used for logging, debugging, or
* displaying information in a user interface.
*
* @param retrieval_type The RetrievalType enum value to convert.
* @return A string representation of the enum value.
*/
std::string RetrievalTypeToString(RetrievalType retrieval_type);
/**
* @struct KVCacheConfig
* @brief Configuration structure for Key-Value (KV) Cache.
*
* This structure holds configuration parameters for setting up and managing
* a Key-Value (KV) Cache used in various attention mechanisms. It includes
* parameters such as the number of layers, the number of heads, the dimension
* of each head, block length, anchor information, and memory-related settings.
*/
struct KVCacheConfig {
int layer_num; /**< Number of layers in the model. */
int kv_head_num; /**< Number of heads in the KV Cache. */
int q_head_num; /**< Number of heads in the query. */
int head_dim; /**< Dimension of each head. */
int block_len; /**< Length of each block in the cache. */
int anchor_num; /**< Number of anchors used in attention. */
ggml_type kv_type; /**< Data type of the KV Cache (e.g., fp16, q8_0). */
// Controls the pre-allocated memory size
int max_block_num; /**< Maximum number of blocks that can be allocated. */
int max_batch_size; /**< Maximum batch size that can be processed. */
int max_thread_num; /**< Maximum number of threads that can be used. */
AnchorType
anchor_type; /**< Type of anchors used in the attention mechanism. */
RetrievalType
retrieval_type; /**< Type of retrieval strategy used in the cache. */
int layer_step; /**< Step size between layers. */
int token_step; /**< Step size between tokens. */
int layer_offset; /**< Offset value for layers. */
/**
* @brief Default constructor for KVCacheConfig.
*
* Initializes the configuration with default values. This constructor
* does not initialize any member variables explicitly.
*/
KVCacheConfig() = default;
/**
* @brief Parameterized constructor for KVCacheConfig.
*
* This constructor initializes the configuration with specific values
* for all member variables.
*
* @param layer_num The number of layers in the model.
* @param kv_head_num The number of heads in the KV Cache.
* @param q_head_num The number of heads in the query.
* @param head_dim The dimension of each head.
* @param block_len The length of each block in the cache.
* @param anchor_num The number of anchors used in attention.
* @param anchor_type The type of anchors used in the attention mechanism.
* @param kv_type The data type of the KV Cache (e.g., fp16, q8_0).
* @param retrieval_type The type of retrieval strategy used in the cache.
* @param layer_step The step size between layers.
* @param token_step The step size between tokens.
* @param layer_offset The offset value for layers.
* @param max_block_num The maximum number of blocks that can be allocated.
* @param max_batch_size The maximum batch size that can be processed.
* @param max_thread_num The maximum number of threads that can be used.
*/
KVCacheConfig(int layer_num, int kv_head_num, int q_head_num, int head_dim,
int block_len, int anchor_num, AnchorType anchor_type,
ggml_type kv_type, RetrievalType retrieval_type,
int layer_step, int token_step, int layer_offset,
int max_block_num, int max_batch_size, int max_thread_num);
};
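For reference, a minimal sketch (not part of this header) of constructing a config with the parameterized constructor above; the values mirror the Python benchmark earlier in this commit:
KVCacheConfig config(
    /*layer_num=*/10, /*kv_head_num=*/8, /*q_head_num=*/32, /*head_dim=*/128,
    /*block_len=*/128, /*anchor_num=*/1,
    AnchorType::DYNAMIC, GGML_TYPE_F16, RetrievalType::LAYER,
    /*layer_step=*/1, /*token_step=*/1, /*layer_offset=*/0,
    /*max_block_num=*/1024, /*max_batch_size=*/1, /*max_thread_num=*/64);
KVCache cache(config);  // KVCache, declared below, is built from this config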
/**
* @class KVCache
* @brief Manages the Key-Value (KV) Cache used in attention mechanisms.
*
* The KVCache class provides functionality for managing the Key-Value Cache,
* including resizing the cache, retrieving configuration parameters, and
* updating internal states. This class is typically used in transformer models
* to store and manage past key and value states for efficient attention
* computations.
*/
class KVCache {
public:
/**
* @brief Constructs a KVCache object with the given configuration.
*
* Initializes the KVCache with the specified configuration parameters,
* such as the number of layers, heads, head dimensions, and other
* relevant settings.
*
* @param config The configuration object containing initialization
* parameters.
*/
KVCache(KVCacheConfig config);
/**
* @brief Resizes the number of threads used by the cache.
*
* This function adjusts the number of threads that the cache can utilize.
* It allows dynamic reconfiguration of the parallel processing capabilities
* based on the current workload or system resources.
*
* @param thread_num The new number of threads to use.
*/
void ThreadResize(int thread_num);
/**
* @brief Resizes the batch size managed by the cache.
*
* This function adjusts the batch size that the cache can handle. It
* is useful when the input batch size changes dynamically, allowing
* the cache to be reconfigured accordingly.
*
* @param batch_size The new batch size.
*/
void BatchResize(int batch_size);
/**
* @brief Resizes the number of blocks managed by the cache.
*
* This function adjusts the number of blocks that the cache can manage.
* It allows dynamic reconfiguration of the block structure based on the
* current sequence length or other factors.
*
* @param block_num The new number of blocks.
*/
void BlockResize(int block_num);
/**
* @brief Gets the number of layers in the cache.
*
* @return The number of layers configured in the cache.
*/
int get_layer_num() { return config_.layer_num; }
/**
* @brief Gets the number of KV heads in the cache.
*
* @return The number of KV heads configured in the cache.
*/
int get_kv_head_num() { return config_.kv_head_num; }
/**
* @brief Gets the number of query heads in the cache.
*
* @return The number of query heads configured in the cache.
*/
int get_q_head_num() { return config_.q_head_num; }
/**
* @brief Gets the dimension of each head in the cache.
*
* @return The dimension of each head.
*/
int get_head_dim() { return config_.head_dim; }
/**
* @brief Gets the length of each block in the cache.
*
* @return The length of each block.
*/
int get_block_len() { return config_.block_len; }
/**
* @brief Gets the number of blocks for a specific layer.
*
* @param layer_id The ID of the layer for which to retrieve the block
* number.
* @return The number of blocks in the specified layer.
*/
int get_block_num(int layer_id) { return past_block_num_[layer_id]; }
/**
* @brief Gets the number of anchors in the cache.
*
* @return The number of anchors configured in the cache.
*/
int get_anchor_num() { return config_.anchor_num; }
/**
* @brief Gets the total length of the cache.
*
* @return The total length of the cache.
*/
int get_cache_total_len() { return cache_total_len_; }
/**
* @brief Gets the total number of blocks in the cache.
*
* This function computes and returns the total number of blocks in the
* cache based on the total cache length and the block length configuration.
*
* @return The total number of blocks in the cache.
*/
int get_cache_total_block_num() {
return (cache_total_len_ + config_.block_len - 1) / config_.block_len;
}
/**
* @brief Updates the total length of the cache.
*
* This function sets a new total length for the cache, allowing dynamic
* adjustment of the cache size during runtime.
*
* @param cache_total_len The new total length of the cache.
*/
void update_cache_total_len(int cache_total_len) {
cache_total_len_ = cache_total_len;
}
void attn(const ggml_fp16_t *q_in, ggml_fp16_t *output, float *attn_lse,
int layer_idx, int generate_token_idx, int q_len, int batch_size,
int max_block_num, int *block_table, int *cache_seqlens,
int pick_block_num, int init_block_num, int local_block_num,
Backend *backend);
void update_kvcache_one_block_fp16(const ggml_fp16_t *k_in,
const ggml_fp16_t *v_in, int layer_id,
int block_idx, Backend *backend);
void get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
int layer_id, int block_idx,
Backend *backend);
void update_importance_one_block(const ggml_fp16_t *importance,
int layer_id, int block_idx,
Backend *backend);
void get_importance_one_block(ggml_fp16_t *importance, int layer_id,
int block_idx, Backend *backend);
void get_anchor_one_block(ggml_fp16_t *anchor, int layer_id, int block_idx,
Backend *backend);
void update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id,
int block_idx, Backend *backend);
void calc_anchor_all_layers(int *block_table, int *cache_seqlens,
int batch_size, int max_block_num,
Backend *backend);
void load_kvcache(std::string tensor_file_path, Backend *backend);
void dump_kvcache(int *block_table, int cache_total_len,
std::string tensor_file_path, Backend *backend);
void get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
int layer_id, int *block_table,
int batch_size, int max_block_num,
int *cache_seqlens, int q_len,
Backend *backend);
void get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id,
int *block_table, int batch_size, int max_block_num,
int *cache_seqlens, Backend *backend);
void update_kvcache_fp16(const ggml_fp16_t *k_in, const ggml_fp16_t *v_in,
int layer_id, int *block_table, int batch_size,
int max_block_num, int *cache_seqlens, int q_len,
Backend *backend);
void update_importance(const ggml_fp16_t *importance, int layer_id,
int *block_table, int batch_size, int max_block_num,
int *offset, int width, Backend *backend);
void attn_with_kvcache(const ggml_fp16_t *q_in, const ggml_fp16_t *k_in,
const ggml_fp16_t *v_in, ggml_fp16_t *output,
float *attn_lse, int layer_idx,
int generate_token_idx, int q_len, int batch_size,
int max_block_num, int *block_table,
int *cache_seqlens, int topk, int local,
Backend *backend);
void clear_importance_all_layers(int *block_table, int *cache_seqlens,
int batch_size, int max_block_num,
Backend *backend);
void clear_kvcache_all_layers(int *block_table, int *cache_seqlens,
int batch_size, int max_block_num,
Backend *backend);
void get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen);
void get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity,
int layer_idx, int generate_token_idx, int q_len,
int batch_size, int max_block_num, int *block_table,
int *cache_seqlens, int *block_table_origin,
int *cache_seqlens_origin, int max_block_num_origin,
int topk, int local, Backend *backend);
void get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in,
ggml_fp16_t *v_in, Backend *backend);
private:
// Persistent data
KVCacheConfig config_;
int n_gqa_; // q_head_num / kv_head_num
int cache_total_len_; // Number of tokens in cache
std::vector<uint64_t> past_block_num_; // [layer_num]
std::vector<std::vector<std::vector<std::vector<block_q4_0>>>>
k_cache_q4; // [layer_num, kv_head_num, past_block_num, block_len *
// (head_dim / QK_4)]
std::vector<std::vector<std::vector<std::vector<block_q4_0>>>>
v_cache_q4; // [layer_num, kv_head_num, past_block_num, head_dim *
// (block_len / QK_4)]
std::vector<std::vector<std::vector<std::vector<block_q8_0>>>>
k_cache_q8; // [layer_num, kv_head_num, past_block_num, block_len *
// (head_dim / QK_8)]
std::vector<std::vector<std::vector<std::vector<block_q8_0>>>>
v_cache_q8; // [layer_num, kv_head_num, past_block_num, head_dim *
// (block_len / QK_8)]
std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>
k_cache_fp16_; // [layer_num, kv_head_num, past_block_num, block_len *
// head_dim]
std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>
v_cache_fp16_; // [layer_num, kv_head_num, past_block_num, head_dim *
// block_len]
std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>
importance_; // [layer_num, past_block_num, block_len,
// attention_head_num]
std::vector<ggml_fp16_t>
anchor_; // [layer_num * past_block_num * anchor_num *
// attention_head_num * head_dim]
// Runtime data
int64_t layer_id_;
int64_t block_idx_;
int *block_table_;
uint64_t block_num_;
int max_block_num_after_retrieval_;
// Rotary positional embeddings
std::vector<std::vector<ggml_fp16_t>> sin_; // [seq_len, head_dim]
std::vector<std::vector<ggml_fp16_t>> cos_; // [seq_len, head_dim]
// update/get
int seq_len_;
uint16_t *k_scales_; // q4_0
uint8_t *k_in_; // q4_0
uint16_t *v_scales_; // q4_0
uint8_t *v_in_; // q4_0
uint16_t *k_data_; // fp16
uint16_t *v_data_; // fp16
uint16_t *importance_data_; // fp16
uint16_t *anchor_data_; // fp16
// sparsity = (sigma(block lse / lse))
std::vector<std::vector<std::vector<float>>>
block_lse_; // [batch_size, max_block_num, q_head_num]
std::vector<std::vector<float>> attn_sparsity_; // [batch_size, q_head_num]
// attn
std::vector<std::vector<float>>
avg_q; // [batch_size, q_head_num * head_dim]
std::vector<std::vector<ggml_fp16_t>>
avg_q_fp16; // [batch_size, q_head_num * head_dim]
std::vector<
std::priority_queue<std::pair<float, int>,
std::vector<std::pair<float, int>>, std::greater<>>>
top_similar_block_;
std::vector<std::vector<float>> block_similar_;
std::vector<std::vector<std::vector<float>>> block_similar_kv_head_;
std::vector<std::vector<std::vector<float>>> block_similar_q_head_;
std::vector<int> cache_seqlens_; // [batch_size]
std::vector<int> selected_blocks_num_history_; // [layer_num // layer_step]
std::vector<std::vector<std::vector<int>>> selected_blocks_history_;
// [layer_num // layer_step, batch_size, max_block_num]
std::vector<std::vector<std::vector<std::vector<int>>>>
selected_blocks_history_kvhead_; // [layer_num // layer_step,
// batch_size, max_block_num,
// kv_head_num]
std::vector<std::vector<int>>
block_table_before_retrieval_; // [batch_size, max_block_num]
std::vector<std::vector<int>>
block_table_after_retrieval_; // [batch_size, pick_block_num]
std::vector<std::vector<std::vector<int>>>
block_table_before_retrieval_qhead_; // [batch_size, max_block_num,
// q_head_num]
std::vector<std::vector<std::vector<int>>>
block_table_after_retrieval_qhead_; // [batch_size, pick_block_num,
// q_head_num]
std::vector<std::vector<std::vector<int>>>
block_table_before_retrieval_kvhead_; // [batch_size, max_block_num,
// kv_head_num]
std::vector<std::vector<std::vector<int>>>
block_table_after_retrieval_kvhead_; // [batch_size, pick_block_num,
// kv_head_num]
std::vector<std::vector<std::unique_ptr<std::mutex>>>
mutex_; // [batch_size, kv_head_num]
std::vector<std::vector<std::vector<block_q8_0>>>
q_q8_0_; // [batch_size, kv_head_num, n_gqa * head_dim / QK8_0]
std::vector<std::vector<std::vector<float>>>
q_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim]
std::vector<std::vector<std::vector<float>>>
output_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim]
std::vector<std::vector<std::vector<float>>>
attn_lse_; // [batch_size, kv_head_num, n_gqa]
std::vector<std::pair<int, int>> thread_cur_head_idx_; // [thread_num]
std::vector<std::vector<block_q8_0>>
thread_local_output_q8_0_; // [thread_num, n_gqa * head_dim / QK8_0]
std::vector<std::vector<float>>
thread_local_attn_score_; // [thread_num, n_gqa * block_len]
std::vector<std::vector<float>>
thread_local_output_fp32_; // [thread_num, n_gqa * head_dim]
std::vector<std::vector<float>>
thread_local_attn_lse_; // [thread_num, n_gqa]
std::vector<std::vector<float>>
thread_local_cur_output_fp32_; // [thread_num, n_gqa * head_dim]
std::vector<std::vector<float>>
thread_local_cur_attn_lse_; // [thread_num, n_gqa]
std::vector<std::vector<uint8_t>>
thread_local_attn_mask_; // [thread_num, block_len // 8]
std::vector<std::vector<char>>
thread_local_draft_; // [thread_num, 2 * n_gqa * block_len + 6 * n_gqa *
// head_dim + 2 * block_len * head_dim]
// tmp space
std::vector<float> q_fp32; // [n_gqa * head_dim]
void quantize_q_(const uint16_t *q_in_data, int batch_size);
void attn_initialize_layer_(int batch_size, int layer_idx, int *block_table,
int &max_block_num, int *cache_seqlens);
void attn_initialize_kvhead_(int batch_size, int layer_idx,
int *block_table, int &max_block_num,
int *cache_seqlens);
void retrieval_kvcache_layer_(const uint16_t *q_in_data, int init_block_num,
int local_block_num, int pick_block_num,
int q_len, int generate_token_idx,
int batch_size, int layer_idx,
int *cache_seqlens, int &max_block_num,
Backend *backend);
void retrieval_kvcache_kvhead_(const uint16_t *q_in_data,
int init_block_num, int local_block_num,
int pick_block_num, int q_len,
int generate_token_idx, int batch_size,
int layer_idx, int *cache_seqlens,
int &max_block_num, Backend *backend);
void calculate_block_similarity_layer_(
const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,
int max_block_num, int *cache_seqlens, int init_block_num,
int local_block_num, int pick_block_num, Backend *backend);
void calculate_block_similarity_kvhead_(
const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,
int max_block_num, int *cache_seqlens, int init_block_num,
int local_block_num, int pick_block_num, Backend *backend);
void select_block_layer_(int batch_size, int layer_idx, int max_block_num,
int init_block_num, int local_block_num,
int pick_block_num);
void select_block_kvhead_(int batch_size, int layer_idx, int max_block_num,
int init_block_num, int local_block_num,
int pick_block_num);
void calculate_sparsity_layer_(const uint16_t *q_in_data,
float *attn_sparsity, int batch_size,
int max_block_num, int *block_table,
int *cache_seqlens, Backend *backend);
void calculate_sparsity_kvhead_(const uint16_t *q_in_data,
float *attn_sparsity, int batch_size,
int max_block_num, int *block_table,
int *cache_seqlens, Backend *backend);
void attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
float *attn_lse, int batch_size, Backend *backend);
void attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output,
float *attn_lse, int batch_size, Backend *backend);
/**
* @brief Computes attention with KV cache for one block.
*
   * This function performs the attention computation for a single block using
   * the KV cache. It supports different data types for the Q, K, and V caches
   * and provides options for quantization. No dynamic memory is allocated
   * internally, so all necessary buffers must be pre-allocated externally.
*
* @param head_dim The dimension of the head.
* @param bsz The batch size.
* @param q_type The data type of Q (GGML data type). Only supports fp16 and
* q8_0.
* @param q Pointer to the Q tensor [bsz, head_dim]. The quantization is
* always applied along the head_dim dimension. The size must be
* bsz * head_dim/32 * qtype_size. If head_dim % 32 != 0, an error
* will be raised.
* @param past_kv_len The length of the past KV cache.
* @param past_kv_offset The offset in the past KV cache.
* @param is_full_attn Boolean flag indicating whether to use full attention
* (true for full 1 mask).
* @param attn_mask Pointer to the attention mask [bsz, past_kv_len]. If
* is_full_attn = false, a bit matrix is passed to
* represent the mask.
* @param k_type The data type of K cache (GGML data type). Only supports
* fp16, q4_0, and q8_0.
* @param k_quant_type Quantization type for K cache. 0 for per_token, 1 for
* per_channel. Other values will raise an error.
* @param k_cache Pointer to the K cache tensor [seq_len, head_dim]. If
* quant_type == 0, head_dim % 32 must be 0. If quant_type ==
* 1, seq_len % 32 must be 0.
* @param num_k_anchor The number of K anchors. If num_k_anchor == 0, it
* means no anchor is present.
* @param k_cache_anchors Pointer to the K cache anchors [num_k_anchor,
* head_dim]. The k_anchor_type must be fp16.
* @param k_cache_anchor_pos Pointer to the K cache anchor positions. Each
* token is associated with the nearest previous anchor position.
* @param v_type The data type of V cache (GGML data type).
* @param v_quant_type Quantization type for V cache.
* @param v_cache Pointer to the V cache tensor [head_dim, seq_len].
* @param num_v_anchor The number of V anchors.
* @param v_cache_anchors Pointer to the V cache anchors.
* @param v_cache_anchor_pos Pointer to the V cache anchor positions.
* @param attn_score Pre-allocated buffer for attention scores [bsz,
* past_kv_len].
* @param output Output tensor [bsz, head_dim] with the same type as q_type.
* @param lse Pre-allocated buffer [bsz] for the log-sum-exp of the
* attention scores.
* @param draft Pre-allocated temporary buffer. The buffer size should be
* enough to hold (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 *
* past_kv_len * head_dim + past_kv_len * head_dim / 32) bytes.
* @param rotary_angle Pointer to the rotary angle tensor.
* @param rotary_cos Pointer to the cosine values for rotary embedding.
* @param rotary_sin Pointer to the sine values for rotary embedding.
*/
void attn_with_kvcache_one_block_(
int head_dim, int bsz,
ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0
// [bsz, head_dim]
// Quantization is always on the head_dim dimension (per_token). If
// head_dim % 32 != 0, an error will be raised. The size must be bsz *
// head_dim/32 * qtype_size.
const void *q,
int past_kv_len, int past_kv_offset,
bool is_full_attn, // true indicates a full 1 mask
// If is_full_attn = false, a bit matrix representing the mask is
// passed. [bsz, past_kv_len]
const uint8_t *attn_mask,
ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,
// q4_0, q8_0
int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an
// error
// [seq_len, head_dim]
// If quant_type == 0, head_dim % 32 must be 0.
// If quant_type == 1, seq_len % 32 must be 0.
const void *k_cache,
// k_anchor_type must be fp16
int num_k_anchor, // num_k_anchor == 0 indicates no anchor
// [num_k_anchor, head_dim]
const void *k_cache_anchors,
      // Each token is associated with the nearest previous anchor position.
const int *k_cache_anchor_pos,
// v_cache similar to k_cache
ggml_type v_type, int v_quant_type,
// [head_dim, seq_len]
const void *v_cache, int num_v_anchor, const void *v_cache_anchors,
const int *v_cache_anchor_pos,
// Pre-allocated buffer for intermediate calculations [bsz,
// past_kv_len]. No malloc is performed inside this function.
float *attn_score,
// Output: [bsz, head_dim], with the same type as q_type
void *output,
// [bsz]
float *lse,
// Pre-allocated temporary buffer with sufficient size:
// (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *
// head_dim + past_kv_len * head_dim / 32) bytes.
void *draft,
// Apply rotary embedding online
const int *rotary_angle, const void *rotary_cos, const void *rotary_sin
// rotary_cos=None,
// rotary_sin=None,
// cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
// cache_batch_idx: Optional[torch.Tensor] = None,
// rotary_interleaved=True,
// // Not supported for now
// window_size=(-1, -1), # -1 means infinite context window
// alibi_slopes=None,
);
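  // Illustrative sizing sketch (not part of the build) for the `draft`
  // scratch buffer documented above, under assumed example values
  // bsz = 4, head_dim = 128, past_kv_len = 128:
  //
  //   size_t draft_bytes = 2 * 4 * 128        // 2 * bsz * past_kv_len
  //                      + 6 * 4 * 128        // 6 * bsz * head_dim
  //                      + 2 * 128 * 128      // 2 * past_kv_len * head_dim
  //                      + 128 * 128 / 32;    // past_kv_len * head_dim / 32
  //   // => 37376 bytes of per-call scratch for these example numbers.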
};
/**
* @brief Scales a float32 vector by a given scalar value.
*
* This function multiplies each element of the input vector `y` by a scalar
* `v`. It uses platform-specific optimizations if available, such as Apple's
* Accelerate framework or SIMD instructions. If no specific optimization is
* available, the function falls back to a simple scalar multiplication loop.
*
* @param n The number of elements in the vector `y`.
* @param y The input vector to be scaled. The result will be stored in the same
* vector.
* @param v The scalar value by which to scale the vector.
*/
void ggml_vec_scale_f32(const int n, float *y, const float v);
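/*
 * A minimal sketch (illustration only, assuming no Accelerate/SIMD path is
 * available) of the scalar fallback described above; the actual
 * implementation may differ:
 *
 *   void ggml_vec_scale_f32(const int n, float *y, const float v) {
 *       for (int i = 0; i < n; ++i)
 *           y[i] *= v; // scale each element of y in place
 *   }
 */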
#endif
/**
* @Description :
* @Author : Jianwei Dong
* @Date : 2024-08-26 22:47:06
* @Version : 1.0.0
* @LastEditors : Jianwei Dong
* @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "kvcache.h"
void KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
float *attn_lse, int batch_size,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
seq_len_ = config_.block_len;
backend->do_work_stealing_job(
batch_size * config_.kv_head_num * max_block_num_after_retrieval_,
[&](int thread_id) {
thread_cur_head_idx_[thread_id].first = -1;
thread_cur_head_idx_[thread_id].second = -1;
},
[&](int task_id) {
int batch_id = task_id / (config_.kv_head_num *
max_block_num_after_retrieval_);
int head_id = (task_id % (config_.kv_head_num *
max_block_num_after_retrieval_)) /
max_block_num_after_retrieval_;
int block_id = task_id % max_block_num_after_retrieval_;
int thread_id = Backend::thread_local_id;
// If the block is out of the sequence length, skip it.
if (cache_seqlens_[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx =
block_table_after_retrieval_kvhead_[batch_id][block_id]
[head_id];
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
if (seq_len == 0)
return;
// Prepare the attention mask for the last block.
int full_blocks = seq_len / 8;
int remaining_bits = seq_len % 8;
                // Fill the fully covered mask bytes with 1s
                for (int i = 0; i < full_blocks; ++i) {
                    thread_local_attn_mask_[thread_id][i] = 0xFF;
                }
                // Fill the remaining bits in the next mask byte
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
thread_local_attn_mask_[thread_id][full_blocks] =
(1 << remaining_bits) - 1;
} else {
thread_local_attn_mask_[thread_id][full_blocks] = 0;
}
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
thread_local_attn_mask_[thread_id][i] = 0;
}
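                // Worked example (illustrative only): with block_len = 128 and
                // seq_len = 19, full_blocks = 2 and remaining_bits = 3, so the
                // mask bytes become 0xFF, 0xFF, 0x07, 0x00, ..., 0x00 -- i.e.
                // only the first 19 key positions of this block are attended.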
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void *)&q_in_data[batch_id * config_.kv_head_num *
n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_F16, 0,
k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1,
v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q4_0, 0,
k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q4_0, 1,
v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q8_0, 0,
k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q8_0, 1,
v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
}
} else {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void *)&q_in_data[batch_id * config_.kv_head_num *
n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,
k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1,
v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,
k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q4_0, 1,
v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,
k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q8_0, 1,
v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
}
}
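            // The partial result of this block is now folded into the running
            // per-(batch, kv_head) accumulator with a streaming log-sum-exp
            // merge. Illustrative form for one output row, where a is the
            // accumulated result and b the new partial result (names are
            // hypothetical):
            //
            //   merged_lse = lse_a + log(1 + exp(lse_b - lse_a));
            //   out[d]     = out_a[d] * exp(lse_a - merged_lse)
            //              + out_b[d] * exp(lse_b - merged_lse);
            //
            // which equals softmax attention taken over both blocks at once.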
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse =
thread_local_cur_attn_lse_[thread_id][i] +
std::log(
1.0 +
std::exp(thread_local_attn_lse_[thread_id][i] -
thread_local_cur_attn_lse_[thread_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id]
[i * config_.head_dim +
j] +=
thread_local_output_fp32_[thread_id]
[i * config_.head_dim + j];
}
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
}
} else {
if (cur_batch_idx != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] =
thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] =
thread_local_cur_output_fp32_
[thread_id]
[i * config_.head_dim + j];
}
continue;
}
float new_attn_lse =
attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(
1.0 +
std::exp(
thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() +
i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] +=
thread_local_cur_output_fp32_
[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
thread_cur_head_idx_[thread_id].first = batch_id;
thread_cur_head_idx_[thread_id].second = head_id;
for (int i = 0; i < n_gqa_; i++) {
thread_local_cur_attn_lse_[thread_id][i] =
thread_local_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_
[thread_id][i * config_.head_dim + j] =
thread_local_output_fp32_[thread_id]
[i * config_.head_dim +
j];
}
}
}
},
// Merge the results of the remaining blocks.
[&](int thread_id) {
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (cur_head_id != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse;
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] =
thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] =
thread_local_cur_output_fp32_
[thread_id]
[i * config_.head_dim + j];
}
continue;
}
new_attn_lse =
attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(
1.0 +
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() +
i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id]
[i * config_.head_dim +
j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
});
// move the results to output and attn_lse
uint16_t *output_data = reinterpret_cast<uint16_t *>(output);
float *attn_lse_data = attn_lse;
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int i = 0; i < config_.kv_head_num; i++) {
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
output_data[batch_idx * config_.kv_head_num * n_gqa_ *
config_.head_dim +
i * n_gqa_ * config_.head_dim + j] =
GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);
}
for (int j = 0; j < n_gqa_; j++) {
attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ +
i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of computing attention: %f s\n", layer_idx,
// diff.count());
}
void KVCache::attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output,
float *attn_lse, int batch_size,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
seq_len_ = config_.block_len;
backend->do_work_stealing_job(
batch_size * config_.kv_head_num * max_block_num_after_retrieval_,
[&](int thread_id) {
thread_cur_head_idx_[thread_id].first = -1;
thread_cur_head_idx_[thread_id].second = -1;
},
[&](int task_id) {
int batch_id = task_id / (config_.kv_head_num *
max_block_num_after_retrieval_);
int head_id = (task_id % (config_.kv_head_num *
max_block_num_after_retrieval_)) /
max_block_num_after_retrieval_;
int block_id = task_id % max_block_num_after_retrieval_;
int thread_id = Backend::thread_local_id;
// If the block is out of the sequence length, skip it.
if (cache_seqlens_[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table_after_retrieval_[batch_id][block_id];
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
if (seq_len == 0)
return;
// Prepare the attention mask for the last block.
int full_blocks = seq_len / 8;
int remaining_bits = seq_len % 8;
                // Fill the fully covered mask bytes with 1s
                for (int i = 0; i < full_blocks; ++i) {
                    thread_local_attn_mask_[thread_id][i] = 0xFF;
                }
                // Fill the remaining bits in the next mask byte
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
thread_local_attn_mask_[thread_id][full_blocks] =
(1 << remaining_bits) - 1;
} else {
thread_local_attn_mask_[thread_id][full_blocks] = 0;
}
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
thread_local_attn_mask_[thread_id][i] = 0;
}
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void *)&q_in_data[batch_id * config_.kv_head_num *
n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_F16, 0,
k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1,
v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q4_0, 0,
k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q4_0, 1,
v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q8_0, 0,
k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q8_0, 1,
v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
}
} else {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void *)&q_in_data[batch_id * config_.kv_head_num *
n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,
k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1,
v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,
k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q4_0, 1,
v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,
k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q8_0, 1,
v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
}
}
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse =
thread_local_cur_attn_lse_[thread_id][i] +
std::log(
1.0 +
std::exp(thread_local_attn_lse_[thread_id][i] -
thread_local_cur_attn_lse_[thread_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id]
[i * config_.head_dim +
j] +=
thread_local_output_fp32_[thread_id]
[i * config_.head_dim + j];
}
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
}
} else {
if (cur_batch_idx != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] =
thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] =
thread_local_cur_output_fp32_
[thread_id]
[i * config_.head_dim + j];
}
continue;
}
float new_attn_lse =
attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(
1.0 +
std::exp(
thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() +
i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] +=
thread_local_cur_output_fp32_
[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
thread_cur_head_idx_[thread_id].first = batch_id;
thread_cur_head_idx_[thread_id].second = head_id;
for (int i = 0; i < n_gqa_; i++) {
thread_local_cur_attn_lse_[thread_id][i] =
thread_local_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_
[thread_id][i * config_.head_dim + j] =
thread_local_output_fp32_[thread_id]
[i * config_.head_dim +
j];
}
}
}
},
// Merge the results of the remaining blocks.
[&](int thread_id) {
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (cur_head_id != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse;
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] =
thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] =
thread_local_cur_output_fp32_
[thread_id]
[i * config_.head_dim + j];
}
continue;
}
new_attn_lse =
attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(
1.0 +
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() +
i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id]
[i * config_.head_dim +
j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
});
// move the results to output and attn_lse
uint16_t *output_data = reinterpret_cast<uint16_t *>(output);
float *attn_lse_data = attn_lse;
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int i = 0; i < config_.kv_head_num; i++) {
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
output_data[batch_idx * config_.kv_head_num * n_gqa_ *
config_.head_dim +
i * n_gqa_ * config_.head_dim + j] =
GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);
}
for (int j = 0; j < n_gqa_; j++) {
attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ +
i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of computing attention: %f s\n", layer_id_,
// diff.count());
}
void KVCache::attn(const ggml_fp16_t *q_in, ggml_fp16_t *output,
float *attn_lse, int layer_idx, int generate_token_idx,
int q_len, int batch_size, int max_block_num,
int *block_table, int *cache_seqlens, int pick_block_num,
int init_block_num, int local_block_num, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_idx;
batch_size = batch_size * q_len;
const uint16_t *q_in_data = const_cast<const uint16_t *>(q_in);
quantize_q_(q_in_data, batch_size);
if (config_.retrieval_type == RetrievalType::LAYER) {
attn_initialize_layer_(batch_size, layer_idx, block_table,
max_block_num, cache_seqlens);
retrieval_kvcache_layer_(q_in_data, init_block_num, local_block_num,
pick_block_num, q_len, generate_token_idx,
batch_size, layer_idx, cache_seqlens,
max_block_num, backend);
attention_layer_(q_in_data, output, attn_lse, batch_size, backend);
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
attn_initialize_kvhead_(batch_size, layer_idx, block_table,
max_block_num, cache_seqlens);
retrieval_kvcache_kvhead_(q_in_data, init_block_num, local_block_num,
pick_block_num, q_len, generate_token_idx,
batch_size, layer_idx, cache_seqlens,
max_block_num, backend);
attention_kvhead_(q_in_data, output, attn_lse, batch_size, backend);
}
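    // Both retrieval modes share the same per-block attention kernel; they
    // differ mainly in whether blocks are selected once per layer or
    // independently for each kv head.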
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of computing attention: %f s\n", layer_idx,
// diff.count());
}
void KVCache::attn_with_kvcache(
const ggml_fp16_t *q_in, const ggml_fp16_t *k_in, const ggml_fp16_t *v_in,
ggml_fp16_t *output, float *attn_lse, int layer_idx, int generate_token_idx,
int q_len, int batch_size, int max_block_num, int *block_table,
int *cache_seqlens, int topk, int local, Backend *backend) {
// printf("attn_with_kvcache start\n");
assert(q_len == 1);
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_idx;
update_kvcache_fp16(k_in, v_in, layer_idx, block_table, batch_size,
max_block_num, cache_seqlens, q_len, backend);
// printf("update finished.\n");
// cache_seqlens memory is modified.
for (int i = 0; i < batch_size; i++) {
cache_seqlens[i] += q_len;
}
int init_block_num = 1;
if (config_.block_len <= 32) {
init_block_num = 64 / config_.block_len;
}
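    // Example: with block_len = 32 the first 64 / 32 = 2 blocks are always
    // kept as initial blocks; for block_len > 32 a single block is kept.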
attn(q_in, output, attn_lse, layer_idx, generate_token_idx, q_len,
batch_size, max_block_num, block_table, cache_seqlens, topk,
init_block_num, local, backend);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of computing attention with kvcache: %f s\n",
// layer_idx, diff.count());
}
void KVCache::quantize_q_(const uint16_t *q_in_data, int batch_size) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
// quantize q
for (int i = 0; i < config_.kv_head_num; i++) {
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
q_fp32_[batch_idx][i][j] = GGML_FP16_TO_FP32(
q_in_data[batch_idx * config_.kv_head_num * n_gqa_ *
config_.head_dim +
i * n_gqa_ * config_.head_dim + j]);
}
}
} else {
// quantize q
for (int i = 0; i < config_.kv_head_num; i++) {
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
q_fp32[j] = GGML_FP16_TO_FP32(
q_in_data[batch_idx * config_.kv_head_num * n_gqa_ *
config_.head_dim +
i * n_gqa_ * config_.head_dim + j]);
}
quantize_row_q8_0(q_fp32.data(), q_q8_0_[batch_idx][i].data(),
n_gqa_ * config_.head_dim);
}
}
}
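    // Each q8_0 row packs 32 consecutive values with one shared fp16 scale
    // (QK8_0 = 32 in ggml), consistent with the head_dim % 32 == 0
    // requirement documented for attn_with_kvcache_one_block_.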
// Timer end
auto end = std::chrono::high_resolution_clock::now();
// printf("time of quantizing q: %f s\n",
// std::chrono::duration<double>(end - start).count());
}
void KVCache::attn_initialize_layer_(int batch_size, int layer_idx,
int *block_table, int &max_block_num,
int *cache_seqlens) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
// initialize output_fp32_ and attn_lse_
for (int i = 0; i < config_.kv_head_num; i++) {
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
output_fp32_[batch_idx][i][j] = 0;
}
for (int j = 0; j < n_gqa_; j++) {
attn_lse_[batch_idx][i][j] = 0;
}
}
// clear top_similar_block_
while (!top_similar_block_[batch_idx].empty())
top_similar_block_[batch_idx].pop();
}
// get block_table_before_retrieval_ and cache_seqlens_
if (block_table == nullptr) {
max_block_num = past_block_num_[layer_idx];
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
if (cache_total_len_ != 0)
cache_seqlens_[batch_idx] = cache_total_len_;
else
cache_seqlens_[batch_idx] = max_block_num * config_.block_len;
for (int i = 0; i < max_block_num; i++) {
block_table_before_retrieval_[batch_idx][i] = i;
block_similar_[batch_idx][i] = 0;
}
}
} else {
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];
for (int i = 0; i < max_block_num; i++) {
block_table_before_retrieval_[batch_idx][i] =
block_table[batch_idx * max_block_num + i];
block_similar_[batch_idx][i] = 0;
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
// printf("layer %d time of initializing attention: %f s\n", layer_idx,
// std::chrono::duration<double>(end - start).count());
}
void KVCache::calculate_block_similarity_layer_(
const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,
int max_block_num, int *cache_seqlens, int init_block_num,
int local_block_num, int pick_block_num, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
if (batch_size == 1 &&
config_.anchor_num == 1) { // TODO: improve batch_size > 1
for (int batch_id = 0; batch_id < batch_size; batch_id++) {
if (q_len == 1) {
for (int j = 0; j < config_.head_dim * config_.q_head_num;
j++) {
avg_q[batch_id][j] = GGML_FP16_TO_FP32(
q_in_data[batch_id * q_len * config_.q_head_num *
config_.head_dim +
j]);
avg_q_fp16[batch_id][j] =
q_in_data[batch_id * q_len * config_.q_head_num *
config_.head_dim +
j];
}
} else {
for (int j = 0; j < config_.head_dim * config_.q_head_num;
j++) {
avg_q[batch_id][j] = 0;
}
for (int i = 0; i < q_len; i++) {
for (int j = 0; j < config_.head_dim; j++) {
avg_q[batch_id][j] += GGML_FP16_TO_FP32(
q_in_data[batch_id * q_len * config_.q_head_num *
config_.head_dim +
i * config_.q_head_num *
config_.head_dim +
j]);
}
}
for (int j = 0; j < config_.head_dim * config_.q_head_num;
j++) {
avg_q[batch_id][j] /= q_len;
avg_q_fp16[batch_id][j] =
GGML_FP32_TO_FP16(avg_q[batch_id][j]);
}
}
int seq_len = cache_seqlens_[batch_id];
int block_num = (seq_len / config_.block_len) - local_block_num -
init_block_num;
if (block_num <= 0) {
continue;
}
bool is_seq = true;
for (int i = init_block_num + 1;
i < (seq_len / config_.block_len) - local_block_num; i++) {
if (block_table_before_retrieval_[batch_id][i] !=
block_table_before_retrieval_[batch_id][i - 1] + 1) {
is_seq = false;
break;
}
}
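            // If the candidate blocks are laid out contiguously, one GEMV over
            // the flattened fp16 anchors suffices: block_similar_[b] is the
            // dot product <anchor_b, avg_q> with vectors of length
            // q_head_num * head_dim. Otherwise each block falls back to its
            // own 1x1 llamafile_sgemm call below.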
if (is_seq) {
int nth = backend->get_thread_num();
backend->do_work_stealing_job(
nth, nullptr,
[&](int task_id) {
int ith = task_id;
bool ok = llamafile_sgemm(
block_num, 1, config_.q_head_num * config_.head_dim,
anchor_.data() +
(layer_idx * config_.max_block_num +
block_table_before_retrieval_
[batch_id][init_block_num]) *
config_.anchor_num * config_.q_head_num *
config_.head_dim,
config_.q_head_num * config_.head_dim,
avg_q_fp16[batch_id].data(),
config_.q_head_num * config_.head_dim,
block_similar_[batch_id].data() + init_block_num,
block_num, ith, nth, GGML_TASK_TYPE_COMPUTE,
GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32,
GGML_PREC_DEFAULT);
if (!ok) {
printf("llamafile_sgemm failed\n");
}
},
nullptr);
} else {
backend->do_work_stealing_job(
block_num, nullptr,
[&](int task_id) {
int block_id = task_id + init_block_num;
int block_idx =
block_table_before_retrieval_[batch_id][block_id];
bool ok = llamafile_sgemm(
1, 1, config_.q_head_num * config_.head_dim,
anchor_.data() +
                                (layer_idx * config_.max_block_num +
                                 block_idx) *
config_.anchor_num * config_.q_head_num *
config_.head_dim,
config_.q_head_num * config_.head_dim,
avg_q_fp16[batch_id].data(),
config_.q_head_num * config_.head_dim,
block_similar_[batch_id].data() + block_id, 1, 0, 1,
GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16,
GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);
if (!ok) {
printf("llamafile_sgemm failed\n");
}
},
nullptr);
}
}
} else {
backend->do_work_stealing_job(
batch_size * max_block_num, nullptr,
[&](int task_id) {
int batch_id = task_id / max_block_num;
int block_id = task_id % max_block_num;
int seq_len = cache_seqlens_[batch_id];
if (block_id < init_block_num ||
block_id >=
(seq_len / config_.block_len) - local_block_num) {
return;
}
int block_idx =
block_table_before_retrieval_[batch_id][block_id];
float sim = 0;
for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
for (int i = 0; i < config_.head_dim; i++) {
float q_i = 0,
qa_i = std::numeric_limits<float>::lowest();
for (int q_id = 0; q_id < q_len; q_id++) {
q_i += GGML_FP16_TO_FP32(
q_in_data[batch_id * q_len *
config_.q_head_num *
config_.head_dim +
q_id * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + i]);
}
q_i /= q_len;
for (int anchor_id = 0; anchor_id < config_.anchor_num;
anchor_id++) {
qa_i = std::max(
qa_i,
GGML_FP16_TO_FP32(
anchor_[(long long)layer_idx *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
anchor_id * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + i]) *
q_i);
}
sim += qa_i;
}
}
block_similar_[batch_id][block_id] = sim;
},
nullptr);
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of calculating similarity: %f s\n", layer_idx,
// diff.count());
}
void KVCache::select_block_layer_(int batch_size, int layer_idx,
int max_block_num, int init_block_num,
int local_block_num, int pick_block_num) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
if (cache_seqlens_[batch_idx] / config_.block_len <=
init_block_num + pick_block_num + local_block_num) {
block_table_after_retrieval_[batch_idx].swap(
block_table_before_retrieval_[batch_idx]);
selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
config_.layer_step] = 0;
continue;
}
for (int block_id = init_block_num;
block_id <
(cache_seqlens_[batch_idx] / config_.block_len) - local_block_num;
block_id++) {
top_similar_block_[batch_idx].push(std::make_pair(
block_similar_[batch_idx][block_id],
block_table_before_retrieval_[batch_idx][block_id]));
if (top_similar_block_[batch_idx].size() > pick_block_num) {
top_similar_block_[batch_idx].pop();
}
}
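        // top_similar_block_ is used as a size-bounded min-heap: once it holds
        // pick_block_num entries, pushing a more similar block evicts the
        // current minimum, leaving the pick_block_num most similar blocks
        // (popped below in ascending order of similarity).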
int i = 0;
for (; i < init_block_num; i++) {
block_table_after_retrieval_[batch_idx][i] =
block_table_before_retrieval_[batch_idx][i];
}
while (!top_similar_block_[batch_idx].empty()) {
block_table_after_retrieval_[batch_idx][i] =
top_similar_block_[batch_idx].top().second;
top_similar_block_[batch_idx].pop();
i++;
}
for (; i < init_block_num + pick_block_num + local_block_num; i++) {
block_table_after_retrieval_[batch_idx][i] =
block_table_before_retrieval_[batch_idx]
[(cache_seqlens_[batch_idx] /
config_.block_len) -
local_block_num + i -
init_block_num - pick_block_num];
}
if (cache_seqlens_[batch_idx] % config_.block_len != 0) {
block_table_after_retrieval_[batch_idx][i] =
block_table_before_retrieval_[batch_idx][(
cache_seqlens_[batch_idx] / config_.block_len)];
cache_seqlens_[batch_idx] =
(cache_seqlens_[batch_idx] % config_.block_len) +
i * config_.block_len;
i++;
} else {
cache_seqlens_[batch_idx] =
(cache_seqlens_[batch_idx] % config_.block_len) +
i * config_.block_len;
}
for (int j = 0; j < i; j++) {
selected_blocks_history_[(layer_idx - config_.layer_offset) /
config_.layer_step][batch_idx][j] =
block_table_after_retrieval_[batch_idx][j];
}
selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
config_.layer_step] = i;
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of selecting blocks: %f s\n", layer_idx,
// diff.count());
}
// Retrieve the KV cache: keep the first init_block_num blocks, the top
// pick_block_num most similar blocks, and the last local_block_num blocks.
// Each task computes the similarity of one block to the query and pushes
// that block into the priority queue; finally, the selected blocks are
// written into block_table_after_retrieval_.
void KVCache::retrieval_kvcache_layer_(const uint16_t *q_in_data,
int init_block_num, int local_block_num,
int pick_block_num, int q_len,
int generate_token_idx, int batch_size,
int layer_idx, int *cache_seqlens,
int &max_block_num, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
max_block_num_after_retrieval_ = 0;
if (pick_block_num != -1 &&
(generate_token_idx % config_.token_step != 0 ||
(layer_idx % config_.layer_step != config_.layer_offset))) {
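        // Fast path: retrieval is refreshed only every token_step tokens and
        // only on layers with layer_idx % layer_step == layer_offset; on all
        // other calls the block selection stored in selected_blocks_history_
        // is reused as-is.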
if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
config_.layer_step] == 0) {
max_block_num_after_retrieval_ = max_block_num;
block_table_after_retrieval_.swap(block_table_before_retrieval_);
} else {
max_block_num_after_retrieval_ = selected_blocks_num_history_
[(layer_idx - config_.layer_offset) / config_.layer_step];
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int i = 0; i < max_block_num_after_retrieval_; i++) {
block_table_after_retrieval_[batch_idx][i] =
selected_blocks_history_[(layer_idx -
config_.layer_offset) /
config_.layer_step][batch_idx]
[i];
}
if (cache_seqlens[batch_idx] % config_.block_len == 1) {
selected_blocks_num_history_[(layer_idx -
config_.layer_offset) /
config_.layer_step] += 1;
int x =
selected_blocks_num_history_[(layer_idx -
config_.layer_offset) /
config_.layer_step];
int last_block_idx =
block_table_before_retrieval_[batch_idx]
[cache_seqlens[batch_idx] /
config_.block_len];
selected_blocks_history_[(layer_idx -
config_.layer_offset) /
config_.layer_step][batch_idx]
[x - 1] = last_block_idx;
block_table_after_retrieval_[batch_idx][x - 1] =
last_block_idx;
}
cache_seqlens_[batch_idx] =
(cache_seqlens_[batch_idx] % config_.block_len) +
selected_blocks_num_history_[(layer_idx -
config_.layer_offset) /
config_.layer_step] *
config_.block_len -
config_.block_len;
}
}
} else if (pick_block_num != -1) {
max_block_num_after_retrieval_ =
std::min(max_block_num,
init_block_num + pick_block_num + local_block_num + 1);
calculate_block_similarity_layer_(q_in_data, batch_size, layer_idx,
q_len, max_block_num, cache_seqlens,
init_block_num, local_block_num,
pick_block_num, backend);
select_block_layer_(batch_size, layer_idx, max_block_num,
init_block_num, local_block_num, pick_block_num);
} else {
selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
config_.layer_step] = 0;
max_block_num_after_retrieval_ = max_block_num;
block_table_after_retrieval_.swap(block_table_before_retrieval_);
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
// printf("layer %d time of retrieval kvcache: %f s\n", layer_idx,
// std::chrono::duration<double>(end - start).count());
}
void KVCache::calculate_sparsity_layer_(const uint16_t *q_in_data,
float *attn_sparsity, int batch_size,
int max_block_num, int *block_table,
                                        int *cache_seqlens, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
seq_len_ = config_.block_len;
backend->do_work_stealing_job(
batch_size * config_.kv_head_num * max_block_num,
[&](int thread_id) {
thread_cur_head_idx_[thread_id].first = -1;
thread_cur_head_idx_[thread_id].second = -1;
},
[&](int task_id) {
int batch_id = task_id / (config_.kv_head_num * max_block_num);
int head_id = (task_id % (config_.kv_head_num * max_block_num)) /
max_block_num;
int block_id = task_id % max_block_num;
int thread_id = Backend::thread_local_id;
// If the block is out of the sequence length, skip it.
if (cache_seqlens[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table[batch_id * max_block_num + block_id];
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
if (seq_len == 0)
return;
// Prepare the attention mask for the last block.
int full_blocks = seq_len / 8;
int remaining_bits = seq_len % 8;
                // Fill the fully covered mask bytes with 1s
                for (int i = 0; i < full_blocks; ++i) {
                    thread_local_attn_mask_[thread_id][i] = 0xFF;
                }
                // Fill the remaining bits in the next mask byte
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
thread_local_attn_mask_[thread_id][full_blocks] =
(1 << remaining_bits) - 1;
} else {
thread_local_attn_mask_[thread_id][full_blocks] = 0;
}
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
thread_local_attn_mask_[thread_id][i] = 0;
}
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void *)&q_in_data[batch_id * config_.kv_head_num *
n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_F16, 0,
k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1,
v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q4_0, 0,
k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q4_0, 1,
v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q8_0, 0,
k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q8_0, 1,
v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
}
} else {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void *)&q_in_data[batch_id * config_.kv_head_num *
n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,
k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1,
v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,
k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q4_0, 1,
v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,
k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q8_0, 1,
v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
}
}
for (int i = 0; i < n_gqa_; i++) {
block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] =
thread_local_attn_lse_[thread_id][i];
}
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse =
thread_local_cur_attn_lse_[thread_id][i] +
std::log(
1.0 +
std::exp(thread_local_attn_lse_[thread_id][i] -
thread_local_cur_attn_lse_[thread_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id]
[i * config_.head_dim +
j] +=
thread_local_output_fp32_[thread_id]
[i * config_.head_dim + j];
}
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
}
} else {
if (cur_batch_idx != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] =
thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] =
thread_local_cur_output_fp32_
[thread_id]
[i * config_.head_dim + j];
}
continue;
}
float new_attn_lse =
attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(
1.0 +
std::exp(
thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() +
i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] +=
thread_local_cur_output_fp32_
[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
thread_cur_head_idx_[thread_id].first = batch_id;
thread_cur_head_idx_[thread_id].second = head_id;
for (int i = 0; i < n_gqa_; i++) {
thread_local_cur_attn_lse_[thread_id][i] =
thread_local_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_
[thread_id][i * config_.head_dim + j] =
thread_local_output_fp32_[thread_id]
[i * config_.head_dim +
j];
}
}
}
},
// Merge the results of the remaining blocks.
[&](int thread_id) {
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (cur_head_id != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse;
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] =
thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] =
thread_local_cur_output_fp32_
[thread_id]
[i * config_.head_dim + j];
}
continue;
}
new_attn_lse =
attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(
1.0 +
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() +
i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id]
[i * config_.head_dim +
j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
});
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < max_block_num_after_retrieval_; j++) {
int block_idx = block_table_after_retrieval_[i][j];
for (int k = 0; k < config_.q_head_num; k++) {
attn_sparsity[i * config_.q_head_num + k] +=
std::exp(block_lse_[i][block_idx][k] -
attn_lse_[i][k / n_gqa_][k % n_gqa_]);
}
}
}
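    // attn_sparsity[i * q_head_num + k] now holds, for each query head, the
    // fraction of the full softmax attention mass captured by the retained
    // blocks (sum over retained blocks of exp(block_lse - total_lse));
    // a value near 1.0 means the sparse selection loses almost nothing.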
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of calculating sparsity: %f s\n", layer_id_,
// diff.count());
}
void KVCache::attn_initialize_kvhead_(int batch_size, int layer_idx,
int *block_table, int &max_block_num,
int *cache_seqlens) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
// initialize output_fp32_ and attn_lse_
for (int i = 0; i < config_.kv_head_num; i++) {
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
output_fp32_[batch_idx][i][j] = 0;
}
for (int j = 0; j < n_gqa_; j++) {
attn_lse_[batch_idx][i][j] = 0;
}
}
// clear top_similar_block_
while (!top_similar_block_[batch_idx].empty())
top_similar_block_[batch_idx].pop();
}
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];
for (int i = 0; i < max_block_num; i++) {
for (int j = 0; j < config_.kv_head_num; j++) {
block_table_before_retrieval_kvhead_[batch_idx][i][j] =
block_table[batch_idx * max_block_num + i];
block_similar_kv_head_[batch_idx][i][j] = 0;
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
// printf("layer %d time of initializing attn: %f s\n", layer_idx,
// std::chrono::duration<double>(end - start).count());
}
void KVCache::retrieval_kvcache_kvhead_(const uint16_t *q_in_data,
int init_block_num, int local_block_num,
int pick_block_num, int q_len,
int generate_token_idx, int batch_size,
int layer_idx, int *cache_seqlens,
int &max_block_num, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
max_block_num_after_retrieval_ = 0;
if (pick_block_num != -1 &&
(generate_token_idx % config_.token_step != 0 ||
(layer_idx % config_.layer_step != config_.layer_offset))) {
if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
config_.layer_step] == 0) {
max_block_num_after_retrieval_ = max_block_num;
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int i = 0; i < max_block_num; i++) {
for (int j = 0; j < config_.kv_head_num; j++) {
block_table_after_retrieval_kvhead_[batch_idx][i][j] =
block_table_before_retrieval_kvhead_[batch_idx][i]
[j];
}
}
}
} else {
max_block_num_after_retrieval_ = selected_blocks_num_history_
[(layer_idx - config_.layer_offset) / config_.layer_step];
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int i = 0; i < max_block_num_after_retrieval_; i++) {
for (int j = 0; j < config_.kv_head_num; j++) {
block_table_after_retrieval_kvhead_[batch_idx][i][j] =
selected_blocks_history_kvhead_
[(layer_idx - config_.layer_offset) /
config_.layer_step][batch_idx][i][j];
}
}
if (cache_seqlens[batch_idx] % config_.block_len == 1) {
selected_blocks_num_history_[(layer_idx -
config_.layer_offset) /
config_.layer_step] += 1;
int x =
selected_blocks_num_history_[(layer_idx -
config_.layer_offset) /
config_.layer_step];
for (int i = 0; i < config_.kv_head_num; i++) {
int last_block_idx =
block_table_before_retrieval_kvhead_
[batch_idx][cache_seqlens[batch_idx] /
config_.block_len][i];
selected_blocks_history_kvhead_[(layer_idx -
config_.layer_offset) /
config_.layer_step]
[batch_idx][x - 1][i] =
last_block_idx;
block_table_after_retrieval_kvhead_[batch_idx][x - 1]
[i] = last_block_idx;
}
}
cache_seqlens_[batch_idx] = std::min(
cache_seqlens_[batch_idx],
(cache_seqlens_[batch_idx] % config_.block_len) +
(init_block_num + pick_block_num + local_block_num) *
config_.block_len);
}
}
} else if (pick_block_num != -1) {
max_block_num_after_retrieval_ =
std::min(max_block_num,
init_block_num + pick_block_num + local_block_num + 1);
calculate_block_similarity_kvhead_(q_in_data, batch_size, layer_idx,
q_len, max_block_num, cache_seqlens,
init_block_num, local_block_num,
pick_block_num, backend);
select_block_kvhead_(batch_size, layer_idx, max_block_num,
init_block_num, local_block_num, pick_block_num);
} else {
selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
config_.layer_step] = 0;
max_block_num_after_retrieval_ = max_block_num;
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int i = 0; i < max_block_num; i++) {
for (int j = 0; j < config_.kv_head_num; j++) {
block_table_after_retrieval_kvhead_[batch_idx][i][j] =
block_table_before_retrieval_kvhead_[batch_idx][i][j];
}
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
// printf("layer %d time of retrieval kvcache: %f s\n", layer_idx,
// std::chrono::duration<double>(end - start).count());
}
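// Measure, per query head, how much of the full attention mass the retrieved
// blocks capture. Each (batch, KV head, block) task computes one block's
// partial attention over the original block table; partial results are merged
// with the streaming log-sum-exp combine below, and the sparsity is then
// accumulated from the per-block LSEs over the retrieved blocks only.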
void KVCache::calculate_sparsity_kvhead_(const uint16_t *q_in_data,
float *attn_sparsity, int batch_size,
int max_block_num, int *block_table,
int *cache_seqlens, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
seq_len_ = config_.block_len;
backend->do_work_stealing_job(
batch_size * config_.kv_head_num * max_block_num,
[&](int thread_id) {
thread_cur_head_idx_[thread_id].first = -1;
thread_cur_head_idx_[thread_id].second = -1;
},
[&](int task_id) {
int batch_id = task_id / (config_.kv_head_num * max_block_num);
int head_id = (task_id % (config_.kv_head_num * max_block_num)) /
max_block_num;
int block_id = task_id % max_block_num;
int thread_id = Backend::thread_local_id;
// If the block is out of the sequence length, skip it.
if (cache_seqlens[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table[batch_id * max_block_num + block_id];
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
if (seq_len == 0)
return;
// Prepare the attention mask for the last block.
int full_blocks = seq_len / 8;
int remaining_bits = seq_len % 8;
// Fill full blocks with 1s
for (int i = 0; i < full_blocks; ++i) {
thread_local_attn_mask_[thread_id][i] = 0xFF;
}
// Fill the remaining bits in the next block
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
thread_local_attn_mask_[thread_id][full_blocks] =
(1 << remaining_bits) - 1;
} else {
thread_local_attn_mask_[thread_id][full_blocks] = 0;
}
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
thread_local_attn_mask_[thread_id][i] = 0;
}
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void *)&q_in_data[batch_id * config_.kv_head_num *
n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_F16, 0,
k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1,
v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q4_0, 0,
k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q4_0, 1,
v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, false,
thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q8_0, 0,
k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q8_0, 1,
v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
}
} else {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void *)&q_in_data[batch_id * config_.kv_head_num *
n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,
k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1,
v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,
k_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q4_0, 1,
v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim,
config_.q_head_num / config_.kv_head_num,
GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),
seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,
k_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_Q8_0, 1,
v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr,
thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr,
cos_.data(), sin_.data());
dequantize_row_q8_0(
thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
n_gqa_ * config_.head_dim);
}
}
for (int i = 0; i < n_gqa_; i++) {
block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] =
thread_local_attn_lse_[thread_id][i];
}
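            // Merge this block's partial result into the thread-local
            // accumulator with the streaming log-sum-exp combine:
            //   lse_new = lse_a + log(1 + exp(lse_b - lse_a)),
            // rescaling both partial outputs by exp(lse - lse_new) before
            // adding them. When the task switches to a different
            // (batch, head) pair, the thread-local result is flushed into the
            // global buffers under that head's mutex.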
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse =
thread_local_cur_attn_lse_[thread_id][i] +
std::log(
1.0 +
std::exp(thread_local_attn_lse_[thread_id][i] -
thread_local_cur_attn_lse_[thread_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id]
[i * config_.head_dim +
j] +=
thread_local_output_fp32_[thread_id]
[i * config_.head_dim + j];
}
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
}
} else {
if (cur_batch_idx != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] =
thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] =
thread_local_cur_output_fp32_
[thread_id]
[i * config_.head_dim + j];
}
continue;
}
float new_attn_lse =
attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(
1.0 +
std::exp(
thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() +
i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] +=
thread_local_cur_output_fp32_
[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
thread_cur_head_idx_[thread_id].first = batch_id;
thread_cur_head_idx_[thread_id].second = head_id;
for (int i = 0; i < n_gqa_; i++) {
thread_local_cur_attn_lse_[thread_id][i] =
thread_local_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_
[thread_id][i * config_.head_dim + j] =
thread_local_output_fp32_[thread_id]
[i * config_.head_dim +
j];
}
}
}
},
// Merge the results of the remaining blocks.
[&](int thread_id) {
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (cur_head_id != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse;
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <
1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] =
thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] =
thread_local_cur_output_fp32_
[thread_id]
[i * config_.head_dim + j];
}
continue;
}
new_attn_lse =
attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(
1.0 +
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(
config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() +
i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -
new_attn_lse));
ggml_vec_scale_f32(
config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() +
i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] -
new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id]
[i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id]
[i * config_.head_dim +
j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
});
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < max_block_num_after_retrieval_; j++) {
for (int k = 0; k < config_.q_head_num; k++) {
int block_idx =
block_table_after_retrieval_kvhead_[i][j][k / n_gqa_];
attn_sparsity[i * config_.q_head_num + k] +=
std::exp(block_lse_[i][block_idx][k] -
attn_lse_[i][k / n_gqa_][k % n_gqa_]);
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of calculating sparsity: %f s\n", layer_id_,
// diff.count());
}
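// Score each candidate block per KV head: average the query over q_len, take
// for every channel the best match against the block's anchors (max over
// anchors of anchor * q_mean), and sum over head_dim. Scores of the query
// heads in one GQA group are accumulated into their shared KV head.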
void KVCache::calculate_block_similarity_kvhead_(
const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,
int max_block_num, int *cache_seqlens, int init_block_num,
int local_block_num, int pick_block_num, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
backend->do_work_stealing_job(
batch_size * max_block_num, nullptr,
[&](int task_id) {
int batch_id = task_id / max_block_num;
int block_id = task_id % max_block_num;
int seq_len = cache_seqlens_[batch_id];
if (block_id < init_block_num ||
block_id >= (seq_len / config_.block_len) - local_block_num) {
return;
}
int block_idx =
block_table_before_retrieval_kvhead_[batch_id][block_id][0];
for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
for (int i = 0; i < config_.head_dim; i++) {
float q_i = 0, qa_i = std::numeric_limits<float>::lowest();
for (int q_id = 0; q_id < q_len; q_id++) {
q_i += GGML_FP16_TO_FP32(
q_in_data[batch_id * q_len * config_.q_head_num *
config_.head_dim +
q_id * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + i]);
}
q_i /= q_len;
for (int anchor_id = 0; anchor_id < config_.anchor_num;
anchor_id++) {
qa_i = std::max(
qa_i,
GGML_FP16_TO_FP32(
anchor_[layer_idx * config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
anchor_id * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + i]) *
q_i);
}
block_similar_kv_head_[batch_id][block_id]
[head_id / n_gqa_] += qa_i;
}
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of calculating similarity: %f s\n", layer_idx,
// diff.count());
}
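// Per KV head, keep the pick_block_num most similar middle blocks in the
// bounded priority queue top_similar_block_ and assemble the retrieved table
// as [init blocks | picked blocks | local (most recent) blocks | partial last
// block], then record the selection and its length in the history so later
// layers/steps can reuse it.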
void KVCache::select_block_kvhead_(int batch_size, int layer_idx,
int max_block_num, int init_block_num,
int local_block_num, int pick_block_num) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
int cache_len_after_retrieval = 0;
if (cache_seqlens_[batch_idx] / config_.block_len <=
init_block_num + pick_block_num + local_block_num) {
selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
config_.layer_step] = 0;
for (int i = 0; i < max_block_num; i++) {
for (int j = 0; j < config_.kv_head_num; j++) {
block_table_after_retrieval_kvhead_[batch_idx][i][j] =
block_table_before_retrieval_kvhead_[batch_idx][i][j];
}
}
continue;
}
for (int head_id = 0; head_id < config_.kv_head_num; head_id++) {
for (int block_id = init_block_num;
block_id < (cache_seqlens_[batch_idx] / config_.block_len) -
local_block_num;
block_id++) {
top_similar_block_[batch_idx].push(std::make_pair(
block_similar_kv_head_[batch_idx][block_id][head_id],
block_table_before_retrieval_kvhead_[batch_idx][block_id]
[head_id]));
if (top_similar_block_[batch_idx].size() > pick_block_num) {
top_similar_block_[batch_idx].pop();
}
}
int i = 0;
for (; i < init_block_num; i++) {
block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
block_table_before_retrieval_kvhead_[batch_idx][i][head_id];
}
while (!top_similar_block_[batch_idx].empty()) {
block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
top_similar_block_[batch_idx].top().second;
top_similar_block_[batch_idx].pop();
i++;
}
for (; i < init_block_num + pick_block_num + local_block_num; i++) {
block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
block_table_before_retrieval_kvhead_
[batch_idx]
[(cache_seqlens_[batch_idx] / config_.block_len) -
local_block_num + i - init_block_num - pick_block_num]
[head_id];
}
if (cache_seqlens_[batch_idx] % config_.block_len != 0) {
block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
block_table_before_retrieval_kvhead_[batch_idx][(
cache_seqlens_[batch_idx] / config_.block_len)]
[head_id];
cache_len_after_retrieval =
(cache_seqlens_[batch_idx] % config_.block_len) +
i * config_.block_len;
i++;
} else {
cache_len_after_retrieval =
(cache_seqlens_[batch_idx] % config_.block_len) +
i * config_.block_len;
}
for (int j = 0; j < i; j++) {
selected_blocks_history_kvhead_
[(layer_idx - config_.layer_offset) / config_.layer_step]
[batch_idx][j][head_id] =
block_table_after_retrieval_kvhead_[batch_idx][j]
[head_id];
}
}
cache_seqlens_[batch_idx] = cache_len_after_retrieval;
selected_blocks_num_history_[(layer_idx - config_.layer_offset) /
config_.layer_step] =
(cache_len_after_retrieval + config_.block_len - 1) /
config_.block_len;
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of selecting block: %f s\n", layer_idx,
// diff.count())
}
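// Entry point for evaluating retrieval quality: runs the configured retrieval
// (LAYER or KVHEAD) and reports, for every query head, the share of the
// full-attention softmax mass covered by the retrieved blocks. The current
// implementation evaluates a single batch (batch_size is forced to 1).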
void KVCache::get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity,
int layer_idx, int generate_token_idx,
int q_len, int batch_size, int max_block_num,
int *block_table, int *cache_seqlens,
int *block_table_origin,
int *cache_seqlens_origin,
int max_block_num_origin, int topk, int local,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_idx;
int thread_num = backend->get_thread_num();
batch_size = 1;
const uint16_t *q_in_data = const_cast<const uint16_t *>(q_in);
quantize_q_(q_in_data, batch_size);
if (config_.retrieval_type == RetrievalType::LAYER) {
attn_initialize_layer_(batch_size, layer_idx, block_table,
max_block_num, cache_seqlens);
retrieval_kvcache_layer_(q_in_data, 1, local, topk, q_len,
generate_token_idx, batch_size, layer_idx,
cache_seqlens, max_block_num, backend);
calculate_sparsity_layer_(q_in_data, attn_sparsity, batch_size,
max_block_num_origin, block_table_origin,
cache_seqlens_origin, backend);
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
attn_initialize_kvhead_(batch_size, layer_idx, block_table,
max_block_num, cache_seqlens);
retrieval_kvcache_kvhead_(q_in_data, 1, local, topk, q_len,
generate_token_idx, batch_size, layer_idx,
cache_seqlens, max_block_num, backend);
calculate_sparsity_kvhead_(q_in_data, attn_sparsity, batch_size,
max_block_num_origin, block_table_origin,
cache_seqlens_origin, backend);
}
}
void KVCache::attn_with_kvcache_one_block_(
int head_dim, int bsz,
ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0
// [bsz, head_dim]
// Quantization is always on the head_dim dimension (per_token). If
// head_dim % 32 != 0, an error will be raised. The size must be bsz *
// head_dim/32 * qtype_size.
const void *q,
int past_kv_len, int past_kv_offset,
bool is_full_attn, // true indicates a full 1 mask
// If is_full_attn = false, a bit matrix representing the mask is
// passed. [bsz, past_kv_len]
const uint8_t *attn_mask,
ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,
// q4_0, q8_0
int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an
// error
// [seq_len, head_dim]
// If quant_type == 0, head_dim % 32 must be 0.
// If quant_type == 1, seq_len % 32 must be 0.
const void *k_cache,
// k_anchor_type must be fp16
int num_k_anchor, // num_k_anchor == 0 indicates no anchor
// [num_k_anchor, head_dim]
const void *k_cache_anchors,
// Each token is associated with the nearest previous position's anchor,
// with the same distance.
const int *k_cache_anchor_pos,
// v_cache similar to k_cache
ggml_type v_type, int v_quant_type,
// [head_dim, seq_len]
const void *v_cache, int num_v_anchor, const void *v_cache_anchors,
const int *v_cache_anchor_pos,
// Pre-allocated buffer for intermediate calculations [bsz,
// past_kv_len]. No malloc is performed inside this function.
float *attn_score,
// Output: [bsz, head_dim], with the same type as q_type
void *output,
// [bsz]
float *lse,
// Pre-allocated temporary buffer with sufficient size:
// (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *
// head_dim + past_kv_len * head_dim / 32) bytes.
void *draft,
// Apply rotary embedding online
const int *rotary_angle, const void *rotary_cos, const void *rotary_sin
// rotary_cos=None,
// rotary_sin=None,
// cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
// cache_batch_idx: Optional[torch.Tensor] = None,
// rotary_interleaved=True,
// // Not supported for now
// window_size=(-1, -1), # -1 means infinite context window
// alibi_slopes=None,
) {
assert(head_dim % 32 == 0);
assert(k_quant_type == 0);
assert(v_quant_type == 1);
assert(q_type == GGML_TYPE_F16 || q_type == GGML_TYPE_Q8_0);
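    // Two paths below. fp16 path: Q*K^T via llamafile_sgemm on half floats
    // (with optional on-the-fly RoPE applied to a scratch copy of K), scale
    // by 1/sqrt(head_dim), mask, softmax (recording the LSE), then attn*V
    // into an fp32 output. Quantized path: the same pipeline on q8_0 Q and
    // q4_0/q8_0 K/V, with the attention weights requantized to q8_0 before
    // the attn*V GEMM and the output written as q8_0.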
if (q_type == GGML_TYPE_F16) {
assert(k_type == GGML_TYPE_F16);
assert(v_type == GGML_TYPE_F16);
// attn = q * k + q * k_anchor
// TODO: anchor
assert(num_k_anchor == 0);
if (rotary_angle != nullptr) {
            ggml_fp16_t *k_cache_with_rope_fp16 =
                reinterpret_cast<ggml_fp16_t *>(
                    reinterpret_cast<char *>(draft) +
                    sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
                    sizeof(float) * bsz * head_dim);
// dequant k_cache and apply rope
// k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)
// k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)
// k(i)cos(i) -> k_rope(i)
// k(i)sin(i+l) -> k_rope(i+l)
// k(i)cos(i) -> k_rope(i)
// -k(i)sin(i-l) -> k_rope(i-l)
std::vector<float> block_fp32(32);
for (int k = 0; k < past_kv_len; k++) {
int angle = rotary_angle[k];
for (int l = 0; l < head_dim / 32; l++) {
for (int m = 0; m < 32; m++) {
float x = GGML_FP16_TO_FP32((
(ggml_fp16_t *)k_cache)[k * head_dim + l * 32 + m]);
float sin_val = GGML_FP16_TO_FP32(
((ggml_fp16_t *)
rotary_sin)[angle * head_dim + l * 32 + m]);
float cos_val = GGML_FP16_TO_FP32(
((ggml_fp16_t *)
rotary_cos)[angle * head_dim + l * 32 + m]);
if (l * 32 + m < head_dim / 2) {
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =
GGML_FP32_TO_FP16(x * cos_val);
k_cache_with_rope_fp16[k * head_dim + l * 32 + m +
head_dim / 2] =
GGML_FP32_TO_FP16(-x * sin_val);
} else {
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =
GGML_FP32_TO_FP16(
GGML_FP16_TO_FP32(
k_cache_with_rope_fp16[k * head_dim +
l * 32 + m]) +
x * sin_val);
k_cache_with_rope_fp16[k * head_dim + l * 32 + m -
head_dim / 2] =
GGML_FP32_TO_FP16(
GGML_FP16_TO_FP32(
k_cache_with_rope_fp16[k * head_dim +
l * 32 + m -
head_dim / 2]) -
x * cos_val);
}
}
}
}
llamafile_sgemm(past_kv_len, bsz, head_dim,
(ggml_fp16_t *)k_cache_with_rope_fp16, head_dim,
(ggml_fp16_t *)q, head_dim, attn_score, past_kv_len,
0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16,
GGML_TYPE_F32, GGML_PREC_DEFAULT);
} else {
bool ok = llamafile_sgemm(
past_kv_len, bsz, head_dim, (ggml_fp16_t *)k_cache, head_dim,
(ggml_fp16_t *)q, head_dim, attn_score, past_kv_len, 0, 1,
GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16, GGML_TYPE_F32,
GGML_PREC_DEFAULT);
if (!ok) {
printf("llamafile_sgemm failed\n");
}
}
// attn = attn * scale
float scale_factor = 1.0 / std::sqrt(float(head_dim));
ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);
// attn = attn & mask
if (!is_full_attn) {
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < past_kv_len; j++) {
int index = i * past_kv_len + j;
if (!(attn_mask[j / 8] & (1 << (j % 8)))) {
attn_score[index] =
std::numeric_limits<float>::lowest();
}
}
}
}
// attn = softmax(attn)
for (int i = 0; i < bsz; i++) {
float sum_exp = 0;
for (int j = 0; j < past_kv_len; j++) {
attn_score[i * past_kv_len + j] =
std::exp(attn_score[i * past_kv_len + j]);
sum_exp += attn_score[i * past_kv_len + j];
}
for (int j = 0; j < past_kv_len; j++) {
attn_score[i * past_kv_len + j] /= sum_exp;
}
if (lse != nullptr) {
lse[i] = std::log(sum_exp);
}
}
// output = attn * v + attn * v_anchor
// std::vector<float> sum(bsz * head_dim);
float *sum = reinterpret_cast<float *>(reinterpret_cast<char *>(draft) +
sizeof(block_q8_0) * bsz *
past_kv_len / QK8_0);
// float* attn_score_fp16(bsz, past_kv_len)
ggml_fp16_t *attn_score_fp16 = (reinterpret_cast<ggml_fp16_t *>(
reinterpret_cast<char *>(draft) +
sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
sizeof(float) * bsz * head_dim));
for (int i = 0; i < bsz * past_kv_len; i++) {
attn_score_fp16[i] = GGML_FP32_TO_FP16(attn_score[i]);
}
// TODO: anchor
assert(num_v_anchor == 0);
bool ok = llamafile_sgemm(
head_dim, bsz, past_kv_len, (ggml_fp16_t *)v_cache, past_kv_len,
(ggml_fp16_t *)attn_score_fp16, past_kv_len, sum, head_dim, 0, 1,
GGML_TASK_TYPE_COMPUTE, v_type, GGML_TYPE_F16, GGML_TYPE_F32,
GGML_PREC_DEFAULT);
if (!ok) {
printf("llamafile_sgemm failed\n");
}
// copy to output
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < head_dim; j++) {
((float *)output)[i * head_dim + j] = sum[i * head_dim + j];
}
}
} else {
assert(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q8_0);
assert(v_type == GGML_TYPE_Q4_0 || v_type == GGML_TYPE_Q8_0);
// attn = q * k + q * k_anchor
// TODO: anchor
assert(num_k_anchor == 0);
if (rotary_angle != nullptr) {
            ggml_fp16_t *k_cache_with_rope_fp16 =
                reinterpret_cast<ggml_fp16_t *>(
                    reinterpret_cast<char *>(draft) +
                    sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
                    sizeof(float) * bsz * head_dim);
            block_q4_0 *k_cache_with_rope_q4 = reinterpret_cast<block_q4_0 *>(
                reinterpret_cast<char *>(draft) +
                sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
                sizeof(float) * bsz * head_dim +
                sizeof(ggml_fp16_t) * bsz * head_dim);
// dequant k_cache and apply rope
// k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)
// k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)
// k(i)cos(i) -> k_rope(i)
// k(i)sin(i+l) -> k_rope(i+l)
// k(i)cos(i) -> k_rope(i)
// -k(i)sin(i-l) -> k_rope(i-l)
std::vector<float> block_fp32(32);
for (int k = 0; k < past_kv_len; k++) {
int angle = rotary_angle[k];
for (int l = 0; l < head_dim / 32; l++) {
block_q4_0 block =
((block_q4_0 *)k_cache)[k * head_dim / 32 + l];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
float sin_val = GGML_FP16_TO_FP32(
((ggml_fp16_t *)
rotary_sin)[angle * head_dim + l * 32 + m]);
float cos_val = GGML_FP16_TO_FP32(
((ggml_fp16_t *)
rotary_cos)[angle * head_dim + l * 32 + m]);
if (l * 32 + m < head_dim / 2) {
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =
GGML_FP32_TO_FP16(block_fp32[m] * cos_val);
k_cache_with_rope_fp16[k * head_dim + l * 32 + m +
head_dim / 2] =
GGML_FP32_TO_FP16(-block_fp32[m] * sin_val);
} else {
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] +=
GGML_FP32_TO_FP16(block_fp32[m] * sin_val);
k_cache_with_rope_fp16[k * head_dim + l * 32 + m -
head_dim / 2] -=
GGML_FP32_TO_FP16(block_fp32[m] * cos_val);
}
}
}
}
// quantize k_cache_with_rope_fp16
for (int k = 0; k < past_kv_len; k++) {
for (int l = 0; l < head_dim / 32; l++) {
for (int m = 0; m < 32; m++) {
block_fp32[m] = GGML_FP16_TO_FP32(
k_cache_with_rope_fp16[k * head_dim + l * 32 + m]);
}
quantize_row_q4_0(
block_fp32.data(),
&k_cache_with_rope_q4[k * head_dim / 32 + l], 32);
}
}
llamafile_sgemm(past_kv_len, bsz, head_dim / 32,
(block_q4_0 *)k_cache_with_rope_q4, head_dim / 32,
(block_q8_0 *)q, head_dim / 32, attn_score,
past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type,
GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);
} else {
llamafile_sgemm(past_kv_len, bsz, head_dim / 32,
(block_q4_0 *)k_cache, head_dim / 32,
(block_q8_0 *)q, head_dim / 32, attn_score,
past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type,
GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);
}
// attn = attn * scale
float scale_factor = 1.0 / std::sqrt(float(head_dim));
ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);
// attn = attn & mask
if (!is_full_attn) {
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < past_kv_len; j++) {
int index = i * past_kv_len + j;
if (!(attn_mask[j / 8] & (1 << (j % 8)))) {
attn_score[index] =
std::numeric_limits<float>::lowest();
}
}
}
}
// attn = softmax(attn)
for (int i = 0; i < bsz; i++) {
float sum_exp = 0;
for (int j = 0; j < past_kv_len; j++) {
attn_score[i * past_kv_len + j] =
std::exp(attn_score[i * past_kv_len + j]);
sum_exp += attn_score[i * past_kv_len + j];
}
for (int j = 0; j < past_kv_len; j++) {
attn_score[i * past_kv_len + j] /= sum_exp;
}
if (lse != nullptr) {
lse[i] = std::log(sum_exp);
}
}
// output = attn * v + attn * v_anchor
// std::vector<block_q8_0> attn_q8_0(bsz * past_kv_len / QK8_0);
block_q8_0 *attn_q8_0 = reinterpret_cast<block_q8_0 *>(draft);
quantize_row_q8_0(attn_score, attn_q8_0, bsz * past_kv_len);
// std::vector<float> sum(bsz * head_dim);
float *sum = reinterpret_cast<float *>(reinterpret_cast<char *>(draft) +
sizeof(block_q8_0) * bsz *
past_kv_len / QK8_0);
// TODO: anchor
assert(num_v_anchor == 0);
llamafile_sgemm(head_dim, bsz, past_kv_len / 32, (block_q4_0 *)v_cache,
past_kv_len / 32, attn_q8_0, past_kv_len / 32, sum,
head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE, v_type,
GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);
quantize_row_q8_0(sum, (block_q8_0 *)output, bsz * head_dim);
}
}
/**
* @Description :
* @Author : Jianwei Dong
* @Date : 2024-08-26 22:47:06
* @Version : 1.0.0
* @LastEditors : Jianwei Dong
* @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "kvcache.h"
void KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
std::ifstream ifs_tensor(tensor_file_path, std::ios::binary);
if (!ifs_tensor) {
throw std::runtime_error("Failed to open tensor file");
}
ifs_tensor.read(reinterpret_cast<char *>(&cache_total_len_),
sizeof(cache_total_len_));
int past_block_num =
(cache_total_len_ + config_.block_len - 1) / config_.block_len;
printf("cache_total_len: %d, past_block_num: %d\n", cache_total_len_,
past_block_num);
for (int i = 0; i < config_.layer_num; ++i) {
past_block_num_[i] = past_block_num;
}
ifs_tensor.read(reinterpret_cast<char *>(anchor_.data()),
anchor_.size() * sizeof(ggml_fp16_t));
for (int i = 0; i < config_.layer_num; ++i) {
for (int j = 0; j < config_.kv_head_num; ++j) {
for (int k = 0; k < past_block_num_[i]; ++k) {
if (config_.kv_type == GGML_TYPE_F16) {
ifs_tensor.read(
reinterpret_cast<char *>(k_cache_fp16_[i][j][k].data()),
k_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));
ifs_tensor.read(
reinterpret_cast<char *>(v_cache_fp16_[i][j][k].data()),
v_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));
} else if (config_.kv_type == GGML_TYPE_Q4_0) {
ifs_tensor.read(
reinterpret_cast<char *>(k_cache_q4[i][j][k].data()),
k_cache_q4[i][j][k].size() * sizeof(block_q4_0));
ifs_tensor.read(
reinterpret_cast<char *>(v_cache_q4[i][j][k].data()),
v_cache_q4[i][j][k].size() * sizeof(block_q4_0));
}
}
}
for (int k = 0; k < past_block_num_[i]; ++k) {
for (int l = 0; l < config_.block_len; l++) {
ifs_tensor.read(
reinterpret_cast<char *>(importance_[i][k][l].data()),
importance_[i][k][l].size() * sizeof(ggml_fp16_t));
}
}
}
ifs_tensor.close();
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
printf("time of load: %f s\n", diff.count());
}
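// Serialize the cache to disk in the layout load_kvcache() expects. Blocks
// are written in logical order by following block_table, so the dumped file
// can be reloaded with a contiguous (identity) block table.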
void KVCache::dump_kvcache(int *block_table, int cache_total_len,
std::string tensor_file_path, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
std::ofstream ofs(tensor_file_path, std::ios::binary);
printf("dump_kvcache: %s\n", tensor_file_path.c_str());
if (!ofs.is_open()) {
std::cerr << "Cannot open file " << tensor_file_path << std::endl;
return;
}
ofs.write(reinterpret_cast<const char *>(&cache_total_len),
sizeof(cache_total_len));
int past_block_num =
(cache_total_len + config_.block_len - 1) / config_.block_len;
printf("cache_total_len: %d, past_block_num: %d\n", cache_total_len,
past_block_num);
ofs.write(reinterpret_cast<const char *>(anchor_.data()),
anchor_.size() * sizeof(ggml_fp16_t));
for (int i = 0; i < config_.layer_num; ++i) {
for (int j = 0; j < config_.kv_head_num; ++j) {
for (int k = 0; k < past_block_num; ++k) {
int block_idx = block_table[k];
if (config_.kv_type == GGML_TYPE_F16) {
ofs.write(reinterpret_cast<const char *>(
k_cache_fp16_[i][j][block_idx].data()),
k_cache_fp16_[i][j][block_idx].size() *
sizeof(ggml_fp16_t));
ofs.write(reinterpret_cast<const char *>(
v_cache_fp16_[i][j][block_idx].data()),
v_cache_fp16_[i][j][block_idx].size() *
sizeof(ggml_fp16_t));
} else if (config_.kv_type == GGML_TYPE_Q4_0) {
ofs.write(reinterpret_cast<const char *>(
k_cache_q4[i][j][block_idx].data()),
k_cache_q4[i][j][block_idx].size() *
sizeof(block_q4_0));
ofs.write(reinterpret_cast<const char *>(
v_cache_q4[i][j][block_idx].data()),
v_cache_q4[i][j][block_idx].size() *
sizeof(block_q4_0));
}
}
}
for (int k = 0; k < past_block_num; ++k) {
int block_idx = block_table[k];
for (int l = 0; l < config_.block_len; l++) {
ofs.write(reinterpret_cast<const char *>(
importance_[i][block_idx][l].data()),
importance_[i][block_idx][l].size() *
sizeof(ggml_fp16_t));
}
}
}
ofs.close();
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
printf("time of dump: %f s\n", diff.count());
}
/**
* @Description :
* @Author : Jianwei Dong
* @Date : 2024-08-26 22:47:06
* @Version : 1.0.0
* @LastEditors : Jianwei Dong
* @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "kvcache.h"
void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,
int block_idx, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
seq_len_ = config_.block_len;
anchor_data_ = const_cast<uint16_t *>(anchor);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
printf("layer %d block %d time of reading anchor: %f s\n", layer_id,
block_idx, duration.count());
}
void KVCache::update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id,
int block_idx, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
seq_len_ = config_.block_len;
anchor_data_ = const_cast<uint16_t *>(anchor);
// Each task updates the anchor of a certain position
// backend->do_work_stealing_job(config_.anchor_num, [&](int task_id) {
// int k = task_id % config_.anchor_num;
// int head_id = task_id / config_.anchor_num;
// memcpy(anchor_[layer_id_][head_id][block_idx].data() +
// k * config_.head_dim,
// anchor_data_ + k * config_.head_dim,
// sizeof(uint16_t) * config_.head_dim);
// });
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
printf("layer %d block %d time of writting anchor: %f s\n", layer_id,
block_idx, duration.count());
}
void KVCache::update_importance_one_block(const ggml_fp16_t *importance,
int layer_id, int block_idx,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
seq_len_ = config_.block_len;
importance_data_ = const_cast<uint16_t *>(importance);
// Each task updates the importance of a certain position
backend->do_work_stealing_job(
config_.block_len, nullptr,
[&](int task_id) {
int k = task_id;
memcpy(importance_[layer_id_][block_idx].data() + k,
importance_data_ + k, sizeof(uint16_t));
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
printf("layer %d block %d time of writting importance: %f s\n", layer_id,
block_idx, duration.count());
}
void KVCache::get_importance_one_block(ggml_fp16_t *importance, int layer_id,
int block_idx, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
seq_len_ = config_.block_len;
importance_data_ = const_cast<uint16_t *>(importance);
// Each task updates the importance of a certain position
backend->do_work_stealing_job(
config_.block_len, nullptr,
[&](int task_id) {
int k = task_id;
memcpy(importance_data_ + k,
importance_[layer_id_][block_idx].data() + k,
sizeof(uint16_t));
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
printf("layer %d block %d time of reading importance: %f s\n", layer_id,
block_idx, duration.count());
}
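// Quantize one incoming fp16 block into the q4_0 cache. K is stored
// token-major and quantized in 32-wide groups along head_dim; V is stored
// transposed (channel-major) and quantized in 32-wide groups along the
// sequence dimension. The cache vectors are grown if block_idx is new.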
void KVCache::update_kvcache_one_block_fp16(const ggml_fp16_t *k_in,
const ggml_fp16_t *v_in,
int layer_id, int block_idx,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
seq_len_ = config_.block_len;
k_data_ = const_cast<uint16_t *>(k_in);
v_data_ = const_cast<uint16_t *>(v_in);
int new_block_num = std::max((int)past_block_num_[layer_id], block_idx + 1);
importance_[layer_id_].resize(new_block_num);
for (int i = 0; i < config_.kv_head_num; i++) {
k_cache_q4[layer_id][i].resize(new_block_num);
v_cache_q4[layer_id][i].resize(new_block_num);
// anchor_[layer_id][i].resize(new_block_num);
}
for (int i = 0; i < new_block_num; i++) {
importance_[layer_id][i].resize(config_.block_len);
}
    // Each task updates the k cache or v cache of a certain head
backend->do_work_stealing_job(
config_.kv_head_num * 2, nullptr,
[&](int task_id) {
std::vector<float> block_fp32(32);
int head_id = task_id / 2;
if (task_id & 1) {
// fill k_cache_
k_cache_q4[layer_id_][head_id][block_idx].resize(
config_.block_len * config_.head_dim / 32);
for (int k = 0; k < config_.block_len; k++) {
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q4_0 block;
for (int m = 0; m < 32; m++) {
block_fp32[m] = GGML_FP16_TO_FP32(
k_data_[((0 * config_.kv_head_num + head_id) *
seq_len_ +
0 * config_.block_len + k) *
config_.head_dim +
l * 32 + m]);
}
quantize_row_q4_0(block_fp32.data(), &block, 32);
k_cache_q4[layer_id_][head_id][block_idx]
[k * config_.head_dim / 32 + l] = block;
}
}
} else {
// fill v_cache_
v_cache_q4[layer_id_][head_id][block_idx].resize(
config_.head_dim * config_.block_len / 32);
for (int k = 0; k < config_.block_len / 32; k++) {
for (int l = 0; l < config_.head_dim; l++) {
block_q4_0 block;
for (int m = 0; m < 32; m++) {
block_fp32[m] = GGML_FP16_TO_FP32(
v_data_[((0 * config_.kv_head_num + head_id) *
seq_len_ +
0 * config_.block_len + k * 32 + m) *
config_.head_dim +
l]);
}
quantize_row_q4_0(block_fp32.data(), &block, 32);
v_cache_q4[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + k] = block;
}
}
}
},
nullptr);
past_block_num_[layer_id] = new_block_num;
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
printf("layer %d block %d time of writting KV Cache: %f s\n", layer_id,
block_idx, duration.count());
// printf("get_one_block_fp16 duration: %ld\n", duration);
}
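// Inverse of update_kvcache_one_block_fp16: dequantize one q4_0 block back to
// fp16, preserving the token-major K layout and the transposed V layout.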
void KVCache::get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
int layer_id, int block_idx,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
seq_len_ = config_.block_len;
k_data_ = reinterpret_cast<uint16_t *>(k_in);
v_data_ = reinterpret_cast<uint16_t *>(v_in);
// printf("layer_id: %d, block_idx: %d\n", layer_id, block_idx);
    // Each task gets the k cache or v cache of a certain head
backend->do_work_stealing_job(
config_.kv_head_num * 2, nullptr,
[&](int task_id) {
std::vector<float> block_fp32(32);
int head_id = task_id / 2;
if (task_id & 1) {
// get k_cache_
for (int k = 0; k < config_.block_len; k++) {
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q4_0 block =
k_cache_q4[layer_id_][head_id][block_idx]
[k * config_.head_dim / 32 + l];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
k_data_[((0 * config_.kv_head_num + head_id) *
seq_len_ +
0 * config_.block_len + k) *
config_.head_dim +
l * 32 + m] =
GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
} else {
// get v_cache_
for (int k = 0; k < config_.block_len / 32; k++) {
for (int l = 0; l < config_.head_dim; l++) {
block_q4_0 block =
v_cache_q4[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + k];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
v_data_[((0 * config_.kv_head_num + head_id) *
seq_len_ +
0 * config_.block_len + k * 32 + m) *
config_.head_dim +
l] = GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
printf("layer %d block %d time of reading KV Cache: %f s\n", layer_id,
block_idx, duration.count());
// printf("get_one_block_fp16 duration: %ld\n", duration);
}
// k_in: (batch_size, seq_len, head_num, head_dim)
// v_in: (batch_size, seq_len, head_num, head_dim)
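// Combined read/append used during prefill: for each (batch, block, KV head)
// task, blocks inside the cached prefix are dequantized into k_in/v_in, while
// the block(s) covering the q_len new positions are written back from
// k_in/v_in (quantized for q4_0/q8_0 caches), so a single pass both exposes
// the existing cache and appends the new tokens.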
void KVCache::get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
int layer_id, int *block_table,
int batch_size, int max_block_num,
int *cache_seqlens, int q_len,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
k_data_ = const_cast<uint16_t *>(k_in);
v_data_ = const_cast<uint16_t *>(v_in);
    // Each task updates the k cache and v cache of a certain head
backend->do_work_stealing_job(
config_.kv_head_num * max_block_num * batch_size, nullptr,
[&](int task_id) {
// printf("block_idx: %d, task_id: %d\n", block_idx, task_id);
std::vector<float> block_fp32(32);
int batch_id = task_id / (config_.kv_head_num * max_block_num);
int block_id = (task_id / config_.kv_head_num) % max_block_num;
int head_id = task_id % config_.kv_head_num;
int block_idx = block_table[batch_id * max_block_num + block_id];
int seq_len = cache_seqlens[batch_id];
int block_l = block_id * config_.block_len;
int block_r = block_id * config_.block_len + config_.block_len;
if (block_l < seq_len) {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
for (int k = 0; k < config_.block_len; k++) {
if (block_id * config_.block_len + k >= seq_len)
break;
for (int l = 0; l < config_.head_dim; l++) {
k_data_
[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num * config_.head_dim) +
block_id *
(config_.block_len * config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num * config_.head_dim) +
head_id * config_.head_dim + l] =
k_cache_fp16_[layer_id_][head_id][block_idx]
[k * config_.head_dim + l];
v_data_
[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num * config_.head_dim) +
block_id *
(config_.block_len * config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num * config_.head_dim) +
head_id * config_.head_dim + l] =
v_cache_fp16_[layer_id_][head_id][block_idx]
[l * config_.block_len + k];
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
// get k_cache_
for (int k = 0; k < config_.block_len; k++) {
if (block_id * config_.block_len + k >= seq_len)
break;
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q4_0 block =
k_cache_q4[layer_id_][head_id][block_idx]
[k * config_.head_dim / 32 + l];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
k_data_[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num *
config_.head_dim) +
head_id * config_.head_dim + l * 32 +
m] = GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
// get v_cache_
for (int k = 0; k < config_.block_len / 32; k++) {
for (int l = 0; l < config_.head_dim; l++) {
block_q4_0 block =
v_cache_q4[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + k];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
if (block_id * config_.block_len + k * 32 + m >=
seq_len)
break;
v_data_[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
(k * 32 + m) * config_.kv_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
// get k_cache_
for (int k = 0; k < config_.block_len; k++) {
if (block_id * config_.block_len + k >= seq_len)
break;
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q8_0 block =
k_cache_q8[layer_id_][head_id][block_idx]
[k * config_.head_dim / 32 + l];
dequantize_row_q8_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
k_data_[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num *
config_.head_dim) +
head_id * config_.head_dim + l * 32 +
m] = GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
// get v_cache_
for (int k = 0; k < config_.block_len / 32; k++) {
for (int l = 0; l < config_.head_dim; l++) {
block_q8_0 block =
v_cache_q8[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + k];
dequantize_row_q8_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
if (block_id * config_.block_len + k * 32 + m >=
seq_len)
break;
v_data_[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
(k * 32 + m) * config_.kv_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
}
}
if (block_r > seq_len && block_l < seq_len + q_len) {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
for (int k = 0; k < config_.block_len; k++) {
if (block_id * config_.block_len + k >=
seq_len + q_len ||
block_id * config_.block_len + k < seq_len)
continue;
for (int l = 0; l < config_.head_dim; l++) {
k_cache_fp16_[layer_id_][head_id][block_idx]
[k * config_.head_dim + l] = k_data_
[batch_id * (max_block_num *
config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num *
config_.head_dim) +
head_id * config_.head_dim + l];
v_cache_fp16_[layer_id_][head_id][block_idx]
[l * config_.block_len + k] = v_data_
[batch_id * (max_block_num *
config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num *
config_.head_dim) +
head_id * config_.head_dim + l];
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
// fill k_cache_
for (int k = 0; k < config_.block_len; k++) {
if (block_id * config_.block_len + k >=
seq_len + q_len ||
block_id * config_.block_len + k < seq_len)
continue;
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q4_0 block;
for (int m = 0; m < 32; m++) {
block_fp32[m] = GGML_FP16_TO_FP32(
k_data_[batch_id * (max_block_num *
config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num *
config_.head_dim) +
head_id * config_.head_dim +
l * 32 + m]);
}
quantize_row_q4_0(block_fp32.data(), &block, 32);
k_cache_q4[layer_id_][head_id][block_idx]
[k * config_.head_dim / 32 + l] = block;
}
}
// fill v_cache_
for (int k = 0; k < config_.block_len / 32; k++) {
for (int l = 0; l < config_.head_dim; l++) {
block_q4_0 block;
for (int m = 0; m < 32; m++) {
if (block_id * config_.block_len + k * 32 + m >=
seq_len + q_len) {
block_fp32[m] = 0;
continue;
}
block_fp32[m] = GGML_FP16_TO_FP32(
v_data_[batch_id * (max_block_num *
config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
(k * 32 + m) * config_.kv_head_num *
config_.head_dim +
head_id * config_.head_dim + l]);
}
quantize_row_q4_0(block_fp32.data(), &block, 32);
v_cache_q4[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + k] = block;
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
// fill k_cache_
for (int k = 0; k < config_.block_len; k++) {
if (block_id * config_.block_len + k >=
seq_len + q_len ||
block_id * config_.block_len + k < seq_len)
continue;
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q8_0 block;
for (int m = 0; m < 32; m++) {
block_fp32[m] = GGML_FP16_TO_FP32(
k_data_[batch_id * (max_block_num *
config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num *
config_.head_dim) +
head_id * config_.head_dim +
l * 32 + m]);
}
quantize_row_q8_0(block_fp32.data(), &block, 32);
k_cache_q8[layer_id_][head_id][block_idx]
[k * config_.head_dim / 32 + l] = block;
}
}
// fill v_cache_
for (int k = 0; k < config_.block_len / 32; k++) {
for (int l = 0; l < config_.head_dim; l++) {
block_q8_0 block;
for (int m = 0; m < 32; m++) {
if (block_id * config_.block_len + k * 32 + m >=
seq_len + q_len) {
block_fp32[m] = 0;
continue;
}
block_fp32[m] = GGML_FP16_TO_FP32(
v_data_[batch_id * (max_block_num *
config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
(k * 32 + m) * config_.kv_head_num *
config_.head_dim +
head_id * config_.head_dim + l]);
}
quantize_row_q8_0(block_fp32.data(), &block, 32);
v_cache_q8[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + k] = block;
}
}
}
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
// printf("layer %d time of reading and updating KV Cache: %f s\n",
// layer_id,
// duration.count());
}
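// Accumulate the per-token, per-query-head importance produced by the last
// forward pass into importance_ (fp16 add), restricted to blocks up to
// (offset + width) / block_len for each batch element.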
void KVCache::update_importance(const ggml_fp16_t *importance, int layer_id,
int *block_table, int batch_size,
int max_block_num, int *offset, int width,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
importance_data_ = const_cast<uint16_t *>(importance);
// Each task updates the importance of a certain position
backend->do_work_stealing_job(
max_block_num * batch_size, nullptr,
[&](int task_id) {
int block_id = task_id % max_block_num;
int batch_id = task_id / max_block_num;
int block_idx = block_table[batch_id * max_block_num + block_id];
if (block_id > (offset[batch_id] + width) / config_.block_len) {
return;
}
for (int k = 0; k < config_.block_len; k++) {
for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
importance_[layer_id_][block_idx][k][head_id] =
GGML_FP32_TO_FP16(
GGML_FP16_TO_FP32(
importance_data_[batch_id * max_block_num *
config_.block_len *
config_.q_head_num +
(block_id * config_.block_len +
k) *
config_.q_head_num +
head_id]) +
GGML_FP16_TO_FP32(
importance_[layer_id_][block_idx][k][head_id]));
}
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
// printf("layer %d time of updating importance: %f s\n", layer_id,
// duration.count());
}
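// Read-only counterpart of get_and_update_kvcache_fp16: dequantize/copy the
// cached K/V for all positions below cache_seqlens into k_in/v_in without
// modifying the cache.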
void KVCache::get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,
int layer_id, int *block_table, int batch_size,
int max_block_num, int *cache_seqlens,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
k_data_ = const_cast<uint16_t *>(k_in);
v_data_ = const_cast<uint16_t *>(v_in);
    // Each task reads the k cache and v cache of a certain head
backend->do_work_stealing_job(
config_.kv_head_num * max_block_num * batch_size, nullptr,
[&](int task_id) {
// printf("block_idx: %d, task_id: %d\n", block_idx, task_id);
std::vector<float> block_fp32(32);
int batch_id = task_id / (config_.kv_head_num * max_block_num);
int block_id = (task_id / config_.kv_head_num) % max_block_num;
int head_id = task_id % config_.kv_head_num;
int block_idx = block_table[batch_id * max_block_num + block_id];
int seq_len = cache_seqlens[batch_id];
int block_l = block_id * config_.block_len;
int block_r = block_id * config_.block_len + config_.block_len;
if (block_l < seq_len) {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
for (int k = 0; k < config_.block_len; k++) {
if (block_id * config_.block_len + k >= seq_len)
break;
for (int l = 0; l < config_.head_dim; l++) {
k_data_
[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num * config_.head_dim) +
block_id *
(config_.block_len * config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num * config_.head_dim) +
head_id * config_.head_dim + l] =
k_cache_fp16_[layer_id_][head_id][block_idx]
[k * config_.head_dim + l];
v_data_
[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num * config_.head_dim) +
block_id *
(config_.block_len * config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num * config_.head_dim) +
head_id * config_.head_dim + l] =
v_cache_fp16_[layer_id_][head_id][block_idx]
[l * config_.block_len + k];
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
// get k_cache_
for (int k = 0; k < config_.block_len; k++) {
if (block_id * config_.block_len + k >= seq_len)
break;
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q4_0 block =
k_cache_q4[layer_id_][head_id][block_idx]
[k * config_.head_dim / 32 + l];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
k_data_[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num *
config_.head_dim) +
head_id * config_.head_dim + l * 32 +
m] = GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
// get v_cache_
for (int k = 0; k < config_.block_len / 32; k++) {
for (int l = 0; l < config_.head_dim; l++) {
block_q4_0 block =
v_cache_q4[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + k];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
if (block_id * config_.block_len + k * 32 + m >=
seq_len)
break;
v_data_[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
(k * 32 + m) * config_.kv_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
// get k_cache_
for (int k = 0; k < config_.block_len; k++) {
if (block_id * config_.block_len + k >= seq_len)
break;
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q8_0 block =
k_cache_q8[layer_id_][head_id][block_idx]
[k * config_.head_dim / 32 + l];
dequantize_row_q8_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
k_data_[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
k * (config_.kv_head_num *
config_.head_dim) +
head_id * config_.head_dim + l * 32 +
m] = GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
// get v_cache_
for (int k = 0; k < config_.block_len / 32; k++) {
for (int l = 0; l < config_.head_dim; l++) {
block_q8_0 block =
v_cache_q8[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + k];
dequantize_row_q8_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
if (block_id * config_.block_len + k * 32 + m >=
seq_len)
break;
v_data_[batch_id *
(max_block_num * config_.block_len *
config_.kv_head_num *
config_.head_dim) +
block_id * (config_.block_len *
config_.kv_head_num *
config_.head_dim) +
(k * 32 + m) * config_.kv_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(block_fp32[m]);
}
}
}
}
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
}
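// Illustrative sketch: the copy loops above all write into the flat
// destination buffers k_data_ / v_data_, laid out as
// [batch][block][token][kv_head][dim]. The helper below is hypothetical and
// only reproduces the same flat offset, with the same parameter meanings as
// the KVCacheConfig fields:
static inline long long kv_dst_offset(int batch_id, int block_id, int pos,
                                      int head_id, int dim_id,
                                      int max_block_num, int block_len,
                                      int kv_head_num, int head_dim) {
    long long per_token = (long long)kv_head_num * head_dim;
    long long per_block = (long long)block_len * per_token;
    long long per_batch = (long long)max_block_num * per_block;
    return batch_id * per_batch + block_id * per_block + pos * per_token +
           (long long)head_id * head_dim + dim_id;
}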
void KVCache::update_kvcache_fp16(const ggml_fp16_t *k_in,
const ggml_fp16_t *v_in, int layer_id,
int *block_table, int batch_size,
int max_block_num, int *cache_seqlens,
int q_len, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
k_data_ = const_cast<uint16_t *>(k_in);
v_data_ = const_cast<uint16_t *>(v_in);
    // Each task updates the k cache and v cache of a certain head
backend->do_work_stealing_job(
batch_size * config_.kv_head_num * q_len, nullptr,
[&](int task_id) {
int batch_id = task_id / (config_.kv_head_num * q_len);
int head_id = task_id / q_len % config_.kv_head_num;
int seq_len = cache_seqlens[batch_id] + task_id % q_len;
int q_offset = task_id % q_len;
int block_id = seq_len / config_.block_len;
int block_idx = block_table[batch_id * max_block_num + block_id];
int pos_in_block = seq_len % config_.block_len;
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
for (int l = 0; l < config_.head_dim; l++) {
k_cache_fp16_[layer_id_][head_id][block_idx]
[pos_in_block * config_.head_dim + l] =
k_data_[batch_id *
(q_len * config_.kv_head_num *
config_.head_dim) +
q_offset * config_.kv_head_num *
config_.head_dim +
head_id * config_.head_dim + l];
v_cache_fp16_[layer_id_][head_id][block_idx]
[l * config_.block_len + pos_in_block] =
v_data_[batch_id *
(q_len * config_.kv_head_num *
config_.head_dim) +
q_offset * config_.kv_head_num *
config_.head_dim +
head_id * config_.head_dim + l];
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
std::vector<float> block_fp32(32);
// fill k_cache_
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q4_0 block;
for (int m = 0; m < 32; m++) {
block_fp32[m] = GGML_FP16_TO_FP32(
                            k_data_[batch_id * (q_len * config_.kv_head_num *
                                                config_.head_dim) +
                                    q_offset * config_.kv_head_num *
                                        config_.head_dim +
                                    head_id * config_.head_dim + l * 32 + m]);
}
quantize_row_q4_0(block_fp32.data(), &block, 32);
k_cache_q4[layer_id_][head_id][block_idx]
[pos_in_block * config_.head_dim / 32 + l] =
block;
}
// fill v_cache_
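                // Read-modify-write: dequantize the 32-entry sub-block that
                // contains pos_in_block, overwrite a single value, then
                // re-quantize the sub-block.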
for (int l = 0; l < config_.head_dim; l++) {
block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 +
pos_in_block / 32];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(
                        v_data_[batch_id * (q_len * config_.kv_head_num *
                                            config_.head_dim) +
                                q_offset * config_.kv_head_num *
                                    config_.head_dim +
                                head_id * config_.head_dim + l]);
quantize_row_q4_0(block_fp32.data(), &block, 32);
v_cache_q4[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + pos_in_block / 32] =
block;
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
std::vector<float> block_fp32(32);
// fill k_cache_
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q8_0 block;
for (int m = 0; m < 32; m++) {
block_fp32[m] = GGML_FP16_TO_FP32(
                            k_data_[batch_id * (q_len * config_.kv_head_num *
                                                config_.head_dim) +
                                    q_offset * config_.kv_head_num *
                                        config_.head_dim +
                                    head_id * config_.head_dim + l * 32 + m]);
}
quantize_row_q8_0(block_fp32.data(), &block, 32);
k_cache_q8[layer_id_][head_id][block_idx]
[pos_in_block * config_.head_dim / 32 + l] =
block;
}
// fill v_cache_
for (int l = 0; l < config_.head_dim; l++) {
block_q8_0 block = v_cache_q8[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 +
pos_in_block / 32];
dequantize_row_q8_0(&block, block_fp32.data(), 32);
block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(
                        v_data_[batch_id * (q_len * config_.kv_head_num *
                                            config_.head_dim) +
                                q_offset * config_.kv_head_num *
                                    config_.head_dim +
                                head_id * config_.head_dim + l]);
quantize_row_q8_0(block_fp32.data(), &block, 32);
v_cache_q8[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + pos_in_block / 32] =
block;
}
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
// printf("layer %d time of reading KV Cache: %f s\n", layer_id,
// duration.count());
}
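// Illustrative sketch: how update_kvcache_fp16 above maps a new token to a
// physical cache slot. The helper and its names are hypothetical; the
// arithmetic mirrors the lambda body (absolute position -> logical block ->
// physical block via the block table -> offset inside the block).
struct TokenSlot {
    int block_idx;    // physical block chosen by the block table
    int pos_in_block; // offset of the token inside that block
};
static inline TokenSlot locate_token(const int *block_table, int batch_id,
                                     int max_block_num, int cache_seqlen,
                                     int q_offset, int block_len) {
    int seq_len = cache_seqlen + q_offset; // absolute position of the token
    int block_id = seq_len / block_len;    // logical block index
    int block_idx = block_table[batch_id * max_block_num + block_id];
    return TokenSlot{block_idx, seq_len % block_len};
}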
void KVCache::get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in,
ggml_fp16_t *v_in, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_id;
seq_len_ = config_.block_len;
block_num_ = get_cache_total_block_num();
k_data_ = reinterpret_cast<uint16_t *>(k_in);
v_data_ = reinterpret_cast<uint16_t *>(v_in);
    // Each task gets the k cache or v cache of a certain head
backend->do_work_stealing_job(
config_.kv_head_num * past_block_num_[layer_id] * 2, nullptr,
[&](int task_id) {
std::vector<float> block_fp32(32);
int head_id = task_id / 2 / past_block_num_[layer_id];
int block_idx = task_id / 2 % past_block_num_[layer_id];
if (block_idx >= block_num_)
return;
int max_offset = 0;
if (task_id & 1) {
// get k_cache_
for (int k = 0; k < config_.block_len; k++) {
if (block_idx * seq_len_ + k >= cache_total_len_)
break;
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q4_0 block =
k_cache_q4[layer_id_][head_id][block_idx]
[k * config_.head_dim / 32 + l];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
k_data_[(head_id * cache_total_len_ +
block_idx * config_.block_len + k) *
config_.head_dim +
l * 32 + m] =
GGML_FP32_TO_FP16(block_fp32[m]);
max_offset = std::max(
max_offset,
(int)(head_id * cache_total_len_ +
block_idx * config_.block_len + k) *
config_.head_dim +
l * 32 + m);
}
}
}
} else {
// get v_cache_
for (int k = 0; k < config_.block_len / 32; k++) {
for (int l = 0; l < config_.head_dim; l++) {
block_q4_0 block =
v_cache_q4[layer_id_][head_id][block_idx]
[l * config_.block_len / 32 + k];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
if (block_idx * seq_len_ + k * 32 + m >=
cache_total_len_)
break;
v_data_[(head_id * cache_total_len_ +
block_idx * config_.block_len + k * 32 +
m) *
config_.head_dim +
l] = GGML_FP32_TO_FP16(block_fp32[m]);
max_offset =
std::max(max_offset,
(int)((head_id * cache_total_len_ +
block_idx * config_.block_len +
k * 32 + m) *
config_.head_dim +
l));
}
}
}
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
// printf("layer %d block num %d time of reading all KV Cache: %f s\n",
// layer_id, block_num_, duration.count());
}
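// Illustrative sketch: the per-block element layout assumed by the readers
// and writers above. K is stored token-major and V dim-major, both for FP16
// elements and for the 32-element quantized sub-blocks:
//   FP16:  k[pos * head_dim + d]                     v[d * block_len + pos]
//   Q4/Q8: k_blocks[pos * (head_dim / 32) + d / 32]  (32 consecutive dims)
//          v_blocks[d * (block_len / 32) + pos / 32] (32 consecutive tokens)
// Hypothetical helpers for the quantized sub-block indices:
static inline int k_quant_block_index(int pos_in_block, int dim_id,
                                      int head_dim) {
    return pos_in_block * (head_dim / 32) + dim_id / 32;
}
static inline int v_quant_block_index(int pos_in_block, int dim_id,
                                      int block_len) {
    return dim_id * (block_len / 32) + pos_in_block / 32;
}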
/**
* @Description :
* @Author : Jianwei Dong
* @Date : 2024-08-26 22:47:06
* @Version : 1.0.0
* @LastEditors : Jianwei Dong
* @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "kvcache.h"
std::string ggml_type_to_string(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return "GGML_TYPE_F32";
case GGML_TYPE_F16:
return "GGML_TYPE_F16";
case GGML_TYPE_Q4_0:
return "GGML_TYPE_Q4_0";
case GGML_TYPE_Q8_0:
return "GGML_TYPE_Q8_0";
}
return "UNDIFINED";
}
std::string AnchorTypeToString(AnchorType type) {
switch (type) {
case AnchorType::DYNAMIC:
return "DYNAMIC";
case AnchorType::BLOCK_MEAN:
return "BLOCK_MEAN";
case AnchorType::BLOCK_MAX:
return "BLOCK_MAX";
case AnchorType::FIXED_ANCHOR:
return "FIXED_ANCHOR";
case AnchorType::QUEST:
return "QUEST";
}
return "UNDIFINED";
}
std::string RetrievalTypeToString(RetrievalType type) {
switch (type) {
case RetrievalType::LAYER:
return "SHARED";
case RetrievalType::KVHEAD:
return "SEPARATE";
case RetrievalType::QHEAD:
return "INDIVIDUAL";
}
return "UNDIFINED";
}
KVCacheConfig::KVCacheConfig(int layer_num, int kv_head_num, int q_head_num,
int head_dim, int block_len, int anchor_num,
AnchorType anchor_type, ggml_type kv_type,
RetrievalType retrieval_type, int layer_step,
int token_step, int layer_offset,
int max_block_num, int max_batch_size,
int max_thread_num)
: layer_num(layer_num), kv_head_num(kv_head_num), q_head_num(q_head_num),
head_dim(head_dim), block_len(block_len), anchor_num(anchor_num),
anchor_type(anchor_type), kv_type(kv_type),
retrieval_type(retrieval_type), layer_step(layer_step),
token_step(token_step), layer_offset(layer_offset),
max_block_num(max_block_num), max_batch_size(max_batch_size),
max_thread_num(max_thread_num) {
printf(
"layer_num: %d, kv_head_num: %d, q_head_num: %d, head_dim: %d, "
"block_len: %d, anchor_num: %d, anchor_type: %s, kv_type: %s, "
"retrieval_type: %s, layer_step: %d, token_step: %d, layer_offset: %d,"
"max_block_num: %d, max_batch_size: %d, max_thread_num: %d\n",
layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num,
AnchorTypeToString(anchor_type).c_str(),
ggml_type_to_string(kv_type).c_str(),
RetrievalTypeToString(retrieval_type).c_str(), layer_step, token_step,
layer_offset, max_block_num, max_batch_size, max_thread_num);
assert(q_head_num % kv_head_num == 0);
}
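// Usage sketch with hypothetical values, for illustration only; the argument
// order follows the constructor above, and q_head_num must be a multiple of
// kv_head_num or the assert fires. Constructing a KVCache from the config
// allocates the per-layer, per-head, per-block storage.
static void example_build_cache() {
    KVCacheConfig cfg(
        /*layer_num=*/8, /*kv_head_num=*/8, /*q_head_num=*/32,
        /*head_dim=*/128, /*block_len=*/128, /*anchor_num=*/1,
        AnchorType::DYNAMIC, GGML_TYPE_F16, RetrievalType::LAYER,
        /*layer_step=*/1, /*token_step=*/1, /*layer_offset=*/0,
        /*max_block_num=*/64, /*max_batch_size=*/1, /*max_thread_num=*/2);
    KVCache cache(cfg);
    (void)cache;
}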
KVCache::KVCache(KVCacheConfig config) {
this->config_ = config;
n_gqa_ = config_.q_head_num / config_.kv_head_num;
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
// TODO: Elegant implement
k_cache_fp16_.resize(config_.layer_num);
v_cache_fp16_.resize(config_.layer_num);
selected_blocks_num_history_.resize(config_.layer_num /
config_.layer_step);
if (config_.retrieval_type == RetrievalType::LAYER) {
selected_blocks_history_.resize(config_.layer_num /
config_.layer_step);
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
selected_blocks_history_kvhead_.resize(config_.layer_num /
config_.layer_step);
} else if (config_.retrieval_type == RetrievalType::QHEAD) {
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
k_cache_q4.resize(config.layer_num);
v_cache_q4.resize(config.layer_num);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
k_cache_q8.resize(config.layer_num);
v_cache_q8.resize(config.layer_num);
} else {
assert(false);
}
anchor_.resize(config.layer_num * config.max_block_num * config.anchor_num *
config.q_head_num * config.head_dim);
importance_.resize(config.layer_num);
past_block_num_.resize(config.layer_num);
for (int i = 0; i < config.layer_num; i++) {
past_block_num_[i] = 0;
}
ThreadResize(config.max_thread_num);
BatchResize(config.max_batch_size);
BlockResize(config.max_block_num);
q_fp32.resize(n_gqa_ * config.head_dim);
}
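// Back-of-the-envelope sketch, for illustration only: approximate bytes of
// K + V storage for one (layer, kv_head, block) triple, based on the resize
// calls in BlockResize below. The quantized sizes assume the usual ggml block
// sizes (block_q4_0: 18 bytes per 32 values, block_q8_0: 34 bytes per 32
// values); treat those constants as assumptions.
static inline long long approx_bytes_per_block(int block_len, int head_dim,
                                               ggml_type t) {
    long long elems = (long long)block_len * head_dim; // per K (same for V)
    switch (t) {
    case GGML_TYPE_F16:
        return 2 * elems * 2; // K + V, 2 bytes per fp16 element
    case GGML_TYPE_Q4_0:
        return 2 * (elems / 32) * 18; // K + V, one block_q4_0 per 32 values
    case GGML_TYPE_Q8_0:
        return 2 * (elems / 32) * 34; // K + V, one block_q8_0 per 32 values
    default:
        return -1; // other types are not handled by this cache
    }
}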
void KVCache::ThreadResize(int thread_num) {
thread_local_output_q8_0_.resize(thread_num);
thread_local_attn_score_.resize(thread_num);
thread_local_output_fp32_.resize(thread_num);
thread_local_attn_lse_.resize(thread_num);
thread_local_cur_output_fp32_.resize(thread_num);
thread_local_cur_attn_lse_.resize(thread_num);
thread_local_draft_.resize(thread_num);
thread_cur_head_idx_.resize(thread_num);
thread_local_attn_mask_.resize(thread_num);
for (int i = 0; i < thread_num; i++) {
thread_local_output_q8_0_[i].resize(n_gqa_ * config_.head_dim / QK8_0);
thread_local_attn_score_[i].resize(n_gqa_ * config_.block_len);
thread_local_output_fp32_[i].resize(n_gqa_ * config_.head_dim);
thread_local_attn_lse_[i].resize(n_gqa_);
thread_local_cur_output_fp32_[i].resize(n_gqa_ * config_.head_dim);
thread_local_cur_attn_lse_[i].resize(n_gqa_);
thread_local_draft_[i].resize(
2 * n_gqa_ * config_.block_len + 6 * n_gqa_ * config_.head_dim +
2 * config_.block_len * config_.head_dim +
config_.block_len * config_.head_dim / QK4_0);
thread_local_attn_mask_[i].resize(config_.block_len / 8);
}
}
void KVCache::BatchResize(int batch_size) {
mutex_.resize(batch_size);
q_q8_0_.resize(batch_size);
q_fp32_.resize(batch_size);
output_fp32_.resize(batch_size);
attn_lse_.resize(batch_size);
block_lse_.resize(batch_size);
attn_sparsity_.resize(batch_size);
if (config_.retrieval_type == RetrievalType::LAYER) {
block_table_before_retrieval_.resize(batch_size);
block_table_after_retrieval_.resize(batch_size);
for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {
selected_blocks_history_[i].resize(batch_size);
}
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
block_table_before_retrieval_kvhead_.resize(batch_size);
block_table_after_retrieval_kvhead_.resize(batch_size);
for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {
selected_blocks_history_kvhead_[i].resize(batch_size);
}
} else if (config_.retrieval_type == RetrievalType::QHEAD) {
block_table_before_retrieval_qhead_.resize(batch_size);
block_table_after_retrieval_qhead_.resize(batch_size);
}
cache_seqlens_.resize(batch_size);
if (config_.retrieval_type == RetrievalType::LAYER) {
block_similar_.resize(batch_size);
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
block_similar_kv_head_.resize(batch_size);
} else if (config_.retrieval_type == RetrievalType::QHEAD) {
block_similar_q_head_.resize(batch_size);
}
for (int i = 0; i < batch_size; i++) {
top_similar_block_.resize(batch_size);
mutex_[i].resize(config_.kv_head_num);
q_q8_0_[i].resize(config_.kv_head_num);
q_fp32_[i].resize(config_.kv_head_num);
output_fp32_[i].resize(config_.kv_head_num);
attn_lse_[i].resize(config_.kv_head_num);
for (int j = 0; j < config_.kv_head_num; j++) {
if (!mutex_[i][j]) {
mutex_[i][j] = std::make_unique<std::mutex>();
}
q_q8_0_[i][j].resize(n_gqa_ * config_.head_dim / QK8_0);
q_fp32_[i][j].resize(n_gqa_ * config_.head_dim);
output_fp32_[i][j].resize(n_gqa_ * config_.head_dim);
attn_lse_[i][j].resize(n_gqa_);
}
}
avg_q.resize(batch_size);
avg_q_fp16.resize(batch_size);
for (int i = 0; i < batch_size; i++) {
attn_sparsity_[i].resize(config_.q_head_num);
avg_q[i].resize(config_.q_head_num * config_.head_dim);
avg_q_fp16[i].resize(config_.q_head_num * config_.head_dim);
}
}
void KVCache::BlockResize(int max_block_num) {
sin_.resize(max_block_num * config_.block_len);
cos_.resize(max_block_num * config_.block_len);
for (int i = 0; i < max_block_num * config_.block_len; i++) {
sin_[i].resize(config_.head_dim);
cos_[i].resize(config_.head_dim);
}
for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {
for (int j = 0; j < config_.max_batch_size; j++) {
if (config_.retrieval_type == RetrievalType::LAYER) {
selected_blocks_history_[i][j].resize(max_block_num);
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
selected_blocks_history_kvhead_[i][j].resize(max_block_num);
for (int k = 0; k < config_.max_block_num; k++) {
selected_blocks_history_kvhead_[i][j][k].resize(
config_.kv_head_num);
}
} else if (config_.retrieval_type == RetrievalType::QHEAD) {
}
}
}
for (int layer_id = 0; layer_id < config_.layer_num; layer_id++) {
importance_[layer_id].resize(max_block_num);
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
// TODO: Elegant implement
k_cache_fp16_[layer_id].resize(config_.kv_head_num);
v_cache_fp16_[layer_id].resize(config_.kv_head_num);
for (int i = 0; i < config_.kv_head_num; i++) {
k_cache_fp16_[layer_id][i].resize(max_block_num);
v_cache_fp16_[layer_id][i].resize(max_block_num);
for (int j = 0; j < max_block_num; j++) {
k_cache_fp16_[layer_id][i][j].resize(config_.block_len *
config_.head_dim);
v_cache_fp16_[layer_id][i][j].resize(config_.block_len *
config_.head_dim);
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
k_cache_q4[layer_id].resize(config_.kv_head_num);
v_cache_q4[layer_id].resize(config_.kv_head_num);
for (int i = 0; i < config_.kv_head_num; i++) {
k_cache_q4[layer_id][i].resize(max_block_num);
v_cache_q4[layer_id][i].resize(max_block_num);
for (int j = 0; j < max_block_num; j++) {
k_cache_q4[layer_id][i][j].resize(config_.block_len *
config_.head_dim / 32);
v_cache_q4[layer_id][i][j].resize(config_.block_len *
config_.head_dim / 32);
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
k_cache_q8[layer_id].resize(config_.kv_head_num);
v_cache_q8[layer_id].resize(config_.kv_head_num);
for (int i = 0; i < config_.kv_head_num; i++) {
k_cache_q8[layer_id][i].resize(max_block_num);
v_cache_q8[layer_id][i].resize(max_block_num);
for (int j = 0; j < max_block_num; j++) {
k_cache_q8[layer_id][i][j].resize(config_.block_len *
config_.head_dim / 32);
v_cache_q8[layer_id][i][j].resize(config_.block_len *
config_.head_dim / 32);
}
}
} else {
assert(false);
}
for (int i = 0; i < config_.max_batch_size; i++) {
if (config_.retrieval_type == RetrievalType::LAYER) {
block_similar_[i].resize(max_block_num);
block_table_before_retrieval_[i].resize(max_block_num);
block_table_after_retrieval_[i].resize(max_block_num);
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
block_similar_kv_head_[i].resize(max_block_num);
block_table_before_retrieval_kvhead_[i].resize(max_block_num);
block_table_after_retrieval_kvhead_[i].resize(max_block_num);
for (int j = 0; j < max_block_num; j++) {
block_similar_kv_head_[i][j].resize(config_.kv_head_num);
block_table_before_retrieval_kvhead_[i][j].resize(
config_.kv_head_num);
block_table_after_retrieval_kvhead_[i][j].resize(
config_.kv_head_num);
}
} else if (config_.retrieval_type == RetrievalType::QHEAD) {
block_similar_q_head_[i].resize(max_block_num);
block_table_before_retrieval_qhead_[i].resize(max_block_num);
block_table_after_retrieval_qhead_[i].resize(max_block_num);
for (int j = 0; j < max_block_num; j++) {
block_similar_q_head_[i][j].resize(config_.q_head_num);
block_table_before_retrieval_qhead_[i][j].resize(
config_.q_head_num);
block_table_after_retrieval_qhead_[i][j].resize(
config_.q_head_num);
}
}
block_lse_[i].resize(max_block_num);
for (int j = 0; j < max_block_num; j++) {
block_lse_[i][j].resize(config_.q_head_num);
}
}
for (int i = 0; i < max_block_num; i++) {
importance_[layer_id][i].resize(config_.block_len);
for (int j = 0; j < config_.block_len; j++) {
importance_[layer_id][i][j].resize(config_.q_head_num);
}
}
}
}
void KVCache::calc_anchor_all_layers(int *block_table, int *cache_seqlens,
int batch_size, int max_block_num,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
    // Each task computes the anchors of a certain block
seq_len_ = config_.block_len;
backend->do_work_stealing_job(
config_.layer_num * batch_size * max_block_num, nullptr,
[&](int task_id) {
int layer_id = task_id / (batch_size * max_block_num);
int batch_id = (task_id / max_block_num) % batch_size;
int block_id = task_id % max_block_num;
            // Skip blocks that lie beyond the sequence length. The last block
            // of the sequence, which may be shorter than the block length, is
            // still processed.
if (cache_seqlens[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table[batch_id * max_block_num + block_id];
std::vector<float> block_fp32(32);
if (config_.anchor_type == AnchorType::DYNAMIC) {
// clear anchor_
for (int anchor_id = 0; anchor_id < 1; anchor_id++) {
for (int head_id = 0; head_id < config_.q_head_num;
head_id++) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num * config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num * config_.head_dim +
anchor_id * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] = 0;
}
}
}
// find top anchor_num importances and their corresponding
// positions in the importance_ tensor
// TODO: Move top_importances to the class member to avoid
// repeated memory allocation
std::priority_queue<
std::pair<float, std::pair<int, int>>,
std::vector<std::pair<float, std::pair<int, int>>>,
std::greater<>>
top_importances;
for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
for (int k = 0; k < seq_len_; k++) {
top_importances.push(std::make_pair(
GGML_FP16_TO_FP32(
importance_[layer_id][block_idx][k][head_id]),
std::make_pair(block_idx, k)));
// TODO: change to config_ item
if (top_importances.size() > config_.anchor_num) {
top_importances.pop();
}
}
// fill anchor_
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num * config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num * config_.head_dim +
0 * config_.q_head_num * config_.head_dim +
head_id * config_.head_dim + l] = 0;
}
for (int k = 0; k < config_.anchor_num; k++) {
int top_indice = top_importances.top().second.second;
int top_block_idx = top_importances.top().second.first;
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
top_block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(
GGML_FP16_TO_FP32(
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
top_block_idx *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l]) +
GGML_FP16_TO_FP32(
k_cache_fp16_[layer_id]
[head_id / n_gqa_]
[top_block_idx]
[top_indice *
config_.head_dim +
l]));
}
} else if (config_.kv_type ==
ggml_type::GGML_TYPE_Q4_0) {
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q4_0 block = k_cache_q4
[layer_id][head_id / n_gqa_][top_block_idx]
[top_indice * config_.head_dim / 32 + l];
dequantize_row_q4_0(&block, block_fp32.data(),
32);
for (int m = 0; m < 32; m++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
top_block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l * 32 + m] =
GGML_FP32_TO_FP16(
block_fp32[m] / 4 +
GGML_FP16_TO_FP32(
anchor_[layer_id *
config_
.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
top_block_idx *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id *
config_.head_dim +
l * 32 + m]));
}
}
} else if (config_.kv_type ==
ggml_type::GGML_TYPE_Q8_0) {
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q8_0 block = k_cache_q8
[layer_id][head_id / n_gqa_][top_block_idx]
[top_indice * config_.head_dim / 32 + l];
dequantize_row_q8_0(&block, block_fp32.data(),
32);
for (int m = 0; m < 32; m++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
top_block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l * 32 + m] =
GGML_FP32_TO_FP16(
block_fp32[m] / 4 +
GGML_FP16_TO_FP32(
anchor_[layer_id *
config_
.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
top_block_idx *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id *
config_.head_dim +
l * 32 + m]));
}
}
}
top_importances.pop();
}
}
} else if (config_.anchor_type == AnchorType::BLOCK_MEAN) {
// clear anchor_
for (int anchor_id = 0; anchor_id < config_.anchor_num;
anchor_id++) {
for (int head_id = 0; head_id < config_.q_head_num;
head_id++) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num * config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num * config_.head_dim +
anchor_id * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] = 0;
}
}
}
// fill anchor_
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
for (int head_id = 0; head_id < config_.q_head_num;
head_id++) {
for (int k = 0; k < config_.block_len; k++) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(
GGML_FP16_TO_FP32(
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l]) +
GGML_FP16_TO_FP32(
k_cache_fp16_[layer_id]
[head_id / n_gqa_]
[block_idx]
[k * config_.head_dim +
l]) /
config_.block_len);
}
}
}
}
} else if (config_.anchor_type == AnchorType::BLOCK_MAX) {
// clear anchor_
for (int anchor_id = 0; anchor_id < config_.anchor_num;
anchor_id++) {
for (int head_id = 0; head_id < config_.q_head_num;
head_id++) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num * config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num * config_.head_dim +
anchor_id * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] = 0;
}
}
}
// fill anchor_
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
for (int head_id = 0; head_id < config_.q_head_num;
head_id++) {
for (int k = 0; k < config_.block_len; k++) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(std::max(
GGML_FP16_TO_FP32(
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l]),
GGML_FP16_TO_FP32(
k_cache_fp16_
[layer_id][head_id / n_gqa_]
[block_idx]
[k * config_.head_dim + l])));
}
}
}
}
} else if (config_.anchor_type == AnchorType::FIXED_ANCHOR) {
// clear anchor_
for (int anchor_id = 0; anchor_id < 1; anchor_id++) {
for (int head_id = 0; head_id < config_.q_head_num;
head_id++) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num * config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num * config_.head_dim +
anchor_id * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] = 0;
}
}
}
// fill anchor_
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
int stride = config_.block_len / config_.anchor_num;
for (int head_id = 0; head_id < config_.q_head_num;
head_id++) {
                    for (int k = 0, tot = 0;
                         k < config_.block_len && tot < config_.anchor_num;
                         k += stride, tot++) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(
GGML_FP16_TO_FP32(
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l]) +
GGML_FP16_TO_FP32(
k_cache_fp16_[layer_id]
[head_id / n_gqa_]
[block_idx]
[k * config_.head_dim +
l]) /
config_.anchor_num);
}
}
}
}
} else if (config_.anchor_type == AnchorType::QUEST) {
// clear anchor_
for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num * config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num * config_.head_dim +
1 * config_.q_head_num * config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(
std::numeric_limits<float>::max());
anchor_[layer_id * config_.max_block_num *
config_.anchor_num * config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num * config_.head_dim +
0 * config_.q_head_num * config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(
                                std::numeric_limits<float>::lowest());
}
}
// fill anchor_
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
for (int indice = 0; indice < seq_len_; indice++) {
for (int head_id = 0; head_id < config_.kv_head_num;
head_id++) {
for (int l = 0; l < config_.head_dim; l++) {
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(std::max(
GGML_FP16_TO_FP32(
k_cache_fp16_
[layer_id][head_id][block_idx]
[indice * config_.head_dim +
l]),
GGML_FP16_TO_FP32(
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l])));
anchor_[layer_id * config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
1 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim + l] =
GGML_FP32_TO_FP16(std::min(
GGML_FP16_TO_FP32(
k_cache_fp16_
[layer_id][head_id][block_idx]
[indice * config_.head_dim +
l]),
GGML_FP16_TO_FP32(
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
1 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l])));
}
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
for (int indice = 0; indice < seq_len_; indice++) {
for (int head_id = 0; head_id < config_.kv_head_num;
head_id++) {
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q4_0 block =
k_cache_q4[layer_id][head_id][block_idx]
[indice * config_.head_dim / 32 +
l];
dequantize_row_q4_0(&block, block_fp32.data(),
32);
for (int m = 0; m < 32; m++) {
for (int gqa_idx = 0; gqa_idx < n_gqa_;
gqa_idx++) {
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l * 32 + m] =
GGML_FP32_TO_FP16(std::max(
block_fp32[m],
GGML_FP16_TO_FP32(
anchor_
[layer_id *
config_
.max_block_num *
config_
.anchor_num *
config_
.q_head_num *
config_.head_dim +
block_idx *
config_
.anchor_num *
config_
.q_head_num *
config_.head_dim +
0 *
config_
.q_head_num *
config_.head_dim +
head_id *
config_.head_dim +
l * 32 + m])));
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
1 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l * 32 + m] =
GGML_FP32_TO_FP16(std::min(
block_fp32[m],
GGML_FP16_TO_FP32(
anchor_
[layer_id *
config_
.max_block_num *
config_
.anchor_num *
config_
.q_head_num *
config_.head_dim +
block_idx *
config_
.anchor_num *
config_
.q_head_num *
config_.head_dim +
1 *
config_
.q_head_num *
config_.head_dim +
head_id *
config_.head_dim +
l * 32 + m])));
}
}
}
}
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
for (int indice = 0; indice < seq_len_; indice++) {
for (int head_id = 0; head_id < config_.kv_head_num;
head_id++) {
for (int l = 0; l < config_.head_dim / 32; l++) {
block_q8_0 block =
k_cache_q8[layer_id][head_id][block_idx]
[indice * config_.head_dim / 32 +
l];
dequantize_row_q8_0(&block, block_fp32.data(),
32);
for (int m = 0; m < 32; m++) {
for (int gqa_idx = 0; gqa_idx < n_gqa_;
gqa_idx++) {
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
0 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l * 32 + m] =
GGML_FP32_TO_FP16(std::max(
block_fp32[m],
GGML_FP16_TO_FP32(
anchor_
[layer_id *
config_
.max_block_num *
config_
.anchor_num *
config_
.q_head_num *
config_.head_dim +
block_idx *
config_
.anchor_num *
config_
.q_head_num *
config_.head_dim +
0 *
config_
.q_head_num *
config_.head_dim +
head_id *
config_.head_dim +
l * 32 + m])));
anchor_[layer_id *
config_.max_block_num *
config_.anchor_num *
config_.q_head_num *
config_.head_dim +
block_idx * config_.anchor_num *
config_.q_head_num *
config_.head_dim +
1 * config_.q_head_num *
config_.head_dim +
head_id * config_.head_dim +
l * 32 + m] =
GGML_FP32_TO_FP16(std::min(
block_fp32[m],
GGML_FP16_TO_FP32(
anchor_
[layer_id *
config_
.max_block_num *
config_
.anchor_num *
config_
.q_head_num *
config_.head_dim +
block_idx *
config_
.anchor_num *
config_
.q_head_num *
config_.head_dim +
1 *
config_
.q_head_num *
config_.head_dim +
head_id *
config_.head_dim +
l * 32 + m])));
}
}
}
}
}
}
} else {
assert(false);
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
// printf("time of calc_anchor_all_layers: %f s\n", duration.count());
}
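// Illustrative sketch: every long index expression into anchor_ above follows
// the same flat [layer][block][anchor][q_head][dim] layout. A hypothetical
// helper computing that offset from the corresponding config values:
static inline long long anchor_offset(int layer_id, int block_idx,
                                      int anchor_id, int head_id, int dim_id,
                                      int max_block_num, int anchor_num,
                                      int q_head_num, int head_dim) {
    long long per_anchor = (long long)q_head_num * head_dim;
    long long per_block = (long long)anchor_num * per_anchor;
    long long per_layer = (long long)max_block_num * per_block;
    return layer_id * per_layer + block_idx * per_block +
           anchor_id * per_anchor + (long long)head_id * head_dim + dim_id;
}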
void KVCache::clear_importance_all_layers(int *block_table, int *cache_seqlens,
int batch_size, int max_block_num,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
    // Each task clears the importance of a certain block
seq_len_ = config_.block_len;
backend->do_work_stealing_job(
config_.layer_num * batch_size * max_block_num, nullptr,
[&](int task_id) {
int layer_id = task_id / (batch_size * max_block_num);
int batch_id = (task_id / max_block_num) % batch_size;
int block_id = task_id % max_block_num;
            // Skip blocks that lie beyond the sequence length. The last block
            // of the sequence, which may be shorter than the block length, is
            // still processed.
if (cache_seqlens[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table[batch_id * max_block_num + block_id];
if (config_.anchor_type == AnchorType::DYNAMIC) {
                // clear importance_
for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
for (int l = 0; l < config_.block_len; l++) {
importance_[layer_id][block_idx][l][head_id] = 0;
}
}
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
// printf("time of clear_importance_all_layerssssss: %f s\n",
// duration.count());
}
void KVCache::clear_kvcache_all_layers(int *block_table, int *cache_seqlens,
int batch_size, int max_block_num,
Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
    // Each task clears the k/v cache of one block for one kv head
seq_len_ = config_.block_len;
backend->do_work_stealing_job(
config_.layer_num * batch_size * max_block_num * config_.kv_head_num,
nullptr,
[&](int task_id) {
int layer_id =
task_id / (batch_size * max_block_num * config_.kv_head_num);
int batch_id =
(task_id / (max_block_num * config_.kv_head_num)) % batch_size;
int block_id = task_id / config_.kv_head_num % max_block_num;
int head_id = task_id % config_.kv_head_num;
            // Skip blocks that lie beyond the sequence length. The last block
            // of the sequence, which may be shorter than the block length, is
            // still processed.
if (cache_seqlens[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table[batch_id * max_block_num + block_id];
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
for (int l = 0; l < config_.block_len * config_.head_dim; l++) {
k_cache_fp16_[layer_id][head_id][block_idx][l] = 0;
v_cache_fp16_[layer_id][head_id][block_idx][l] = 0;
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
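                // Zeroing only the per-sub-block scale d is sufficient here:
                // dequantization multiplies by d, so the values read back are
                // all zero even though the packed quants are left untouched.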
for (int l = 0; l < config_.block_len * config_.head_dim / 32;
l++) {
k_cache_q4[layer_id][head_id][block_idx][l].d = 0;
v_cache_q4[layer_id][head_id][block_idx][l].d = 0;
}
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
for (int l = 0; l < config_.block_len * config_.head_dim / 32;
l++) {
k_cache_q8[layer_id][head_id][block_idx][l].d = 0;
v_cache_q8[layer_id][head_id][block_idx][l].d = 0;
}
}
},
nullptr);
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
// printf("time of clear_kvcache_all_layers: %f s\n", duration.count());
}
void KVCache::get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
const uint16_t *sin_data = const_cast<const uint16_t *>(sin);
const uint16_t *cos_data = const_cast<const uint16_t *>(cos);
for (int i = 0; i < seqlen; i++) {
for (int j = 0; j < config_.head_dim; j++) {
sin_[i][j] = sin_data[i * config_.head_dim + j];
cos_[i][j] = cos_data[i * config_.head_dim + j];
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;
printf("time of get_sincos: %f s\n", duration.count());
}
void ggml_vec_scale_f32(const int n, float *y, const float v) {
#if defined(GGML_USE_ACCELERATE)
vDSP_vsmul(y, 1, &v, y, 1, n);
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F32_STEP - 1));
GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
GGML_F32_VEC ay[GGML_F32_ARR];
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ay[j] = GGML_F32_VEC_LOAD(y + i + j * GGML_F32_EPR);
ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
GGML_F32_VEC_STORE(y + i + j * GGML_F32_EPR, ay[j]);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] *= v;
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] *= v;
}
#endif
}
\ No newline at end of file
......@@ -3,8 +3,8 @@
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:34:58
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-15 07:45:18
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "linear.h"
......@@ -24,10 +24,14 @@ Linear::~Linear() {
shared_mem_buffer.dealloc(this);
}
void Linear::warm_up(Backend* backend) {
void Linear::warm_up(Backend *backend) {
std::vector<float> input_fp32(config_.input_size);
std::vector<uint8_t> input(config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
std::vector<uint8_t> output(config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
std::vector<uint8_t> input(config_.input_size *
ggml_type_size(config_.hidden_type) /
ggml_blck_size(config_.hidden_type));
std::vector<uint8_t> output(config_.output_size *
ggml_type_size(config_.hidden_type) /
ggml_blck_size(config_.hidden_type));
for (int i = 0; i < config_.input_size; i++) {
input_fp32[i] = 0;
}
......@@ -45,7 +49,7 @@ void Linear::forward_many(int qlen, const void* input, void* output, Backend* ba
proj_input_ptr = proj_input_;
}
int nth = config_.output_size / config_.stride;
backend->do_work_stealing_job(nth, [&](int task_id) {
backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {
int ith = task_id;
void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
float* proj_output_ptr = proj_output_ + ith * config_.stride;
......@@ -57,7 +61,7 @@ void Linear::forward_many(int qlen, const void* input, void* output, Backend* ba
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
}
});
}, nullptr);
if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
from_float(proj_output_, output, qlen * config_.output_size, config_.hidden_type);
}
......
......@@ -3,8 +3,8 @@
* @Author : chenht2022
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:04
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-15 07:44:38
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "mlp.h"
......@@ -31,10 +31,14 @@ MLP::~MLP() {
shared_mem_buffer.dealloc(this);
}
void MLP::warm_up(Backend* backend) {
void MLP::warm_up(Backend *backend) {
std::vector<float> input_fp32(config_.hidden_size);
std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
std::vector<uint8_t> output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
std::vector<uint8_t> input(config_.hidden_size *
ggml_type_size(config_.hidden_type) /
ggml_blck_size(config_.hidden_type));
std::vector<uint8_t> output(config_.hidden_size *
ggml_type_size(config_.hidden_type) /
ggml_blck_size(config_.hidden_type));
for (int i = 0; i < config_.hidden_size; i++) {
input_fp32[i] = 0;
}
......@@ -42,9 +46,7 @@ void MLP::warm_up(Backend* backend) {
forward_many(1, input.data(), output.data(), backend);
}
static float act_fn(float x) {
return x / (1.0f + expf(-x));
}
static float act_fn(float x) { return x / (1.0f + expf(-x)); }
void MLP::forward_many(int qlen, const void* input, void* output, Backend* backend) {
const void* gate_input_ptr;
......@@ -72,7 +74,7 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
}
}
int nth = config_.intermediate_size / config_.stride;
backend->do_work_stealing_job(nth, [&](int task_id) {
backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {
int ith = task_id;
void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
float* gate_output_ptr = gate_output_ + ith * config_.stride;
......@@ -90,12 +92,12 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
}
});
}, nullptr);
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {
from_float(intermediate_fp32_, down_input_, qlen * config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
nth = config_.hidden_size / config_.stride;
backend->do_work_stealing_job(nth, [&](int task_id) {
backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {
int ith = task_id;
void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
float* down_output_ptr = down_output_ + ith * config_.stride;
......@@ -107,7 +109,7 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
}
});
}, nullptr);
if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
from_float(down_output_, output, qlen * config_.hidden_size, config_.hidden_type);
}
......
......@@ -3,8 +3,8 @@
* @Author : chenht2022
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:07
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-15 07:43:41
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "moe.h"
......@@ -121,7 +121,7 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
}
}
int nth = config_.intermediate_size / config_.stride;
backend->do_work_stealing_job(nth * k, [&](int task_id) {
backend->do_work_stealing_job(nth * k, nullptr, [&](int task_id) {
int expert_idx = task_id / nth;
uint64_t expert_id = expert_ids[expert_idx];
int ith = task_id % nth;
......@@ -139,14 +139,14 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
void* down_input_ptr = s_down_input_[expert_idx] + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
});
}, nullptr);
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {
for (int i = 0; i < k; i++) {
from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
}
nth = config_.hidden_size / config_.stride;
backend->do_work_stealing_job(nth, [&](int task_id) {
backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {
int ith = task_id;
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
s_output_fp32_[i] = 0;
......@@ -165,7 +165,7 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
void* output_ptr = (uint8_t*)output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
});
}, nullptr);
if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
from_float(s_output_fp32_, output, config_.hidden_size, config_.hidden_type);
}
......@@ -191,7 +191,7 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
}
backend->do_work_stealing_job(qlen, [&](int i) {
backend->do_work_stealing_job(qlen, nullptr, [&](int i) {
const void* gate_input_ptr;
const void* up_input_ptr;
if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
......@@ -220,10 +220,10 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type), gate_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type));
memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type), up_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type));
}
});
}, nullptr);
int stride = QK_K;
int nth = config_.intermediate_size / stride;
backend->do_work_stealing_job(nth * config_.expert_num, [&](int task_id) {
backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {
int expert_idx = task_id / nth;
int ith = task_id % nth;
void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];
......@@ -242,18 +242,18 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_ptr, down_input_ptr, stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
});
}, nullptr);
stride = QK_K;
nth = config_.hidden_size / stride;
backend->do_work_stealing_job(nth * config_.expert_num, [&](int task_id) {
backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {
int expert_idx = task_id / nth;
int ith = task_id % nth;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx];
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
});
backend->do_work_stealing_job(qlen, [&](int i) {
}, nullptr);
backend->do_work_stealing_job(qlen, nullptr, [&](int i) {
for (int e = 0; e < config_.hidden_size; e++) {
m_output_fp32_[i][e] = 0;
}
......@@ -263,7 +263,7 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
}
}
from_float(m_output_fp32_[i], (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type);
});
}, nullptr);
}
void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
......
# Copyright 2024 Shaoyuan Chen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os
import platform
import sys
project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
......@@ -31,6 +25,7 @@ import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config
......@@ -38,38 +33,56 @@ from ktransformers.server.config.config import Config
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
default_optimize_rules ={
ktransformer_rules_dir = (
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = {
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
"LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
"MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
}
def local_chat(
model_path: str,
model_path: str | None = None,
optimize_rule_path: str = None,
gguf_path: str = None,
gguf_path: str | None = None,
max_new_tokens: int = 1000,
cpu_infer: int = Config().cpu_infer,
use_cuda_graph: bool = True,
prompt_file : str | None = None,
mode: str = "normal",
):
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
torch.set_default_dtype(config.torch_dtype)
if mode == 'long_context':
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
if (
"Qwen2Moe" in config.architectures[0]
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Mixtral" in config.architectures[0]:
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
......@@ -95,26 +108,50 @@ def local_chat(
if model.generation_config.pad_token_id is None:
model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.eval()
logging.basicConfig(level=logging.INFO)
system = platform.system()
if (system == u'Windows'):
os.system('cls')
if system == "Windows":
os.system("cls")
else:
os.system('clear')
os.system("clear")
while True:
content = input("Chat: ")
if content == "":
content = "Please write a piece of quicksort code in C++."
if content.startswith('"""'): # prefix """
# multi lines input
content = content[3:] + "\n"
while True:
line = input("")
if line.endswith('"""'):
# end multi lines input
line = line[:-3] # suffix """
if line:
content += line + "\n"
break
else:
content += line + "\n"
if content == "":
if prompt_file != None:
content = open(prompt_file, "r").read()
else:
content = "Please write a piece of quicksort code in C++."
elif os.path.isfile(content):
content = open(content, "r").read()
messages = [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
)
torch.set_default_dtype(torch.bfloat16) # TODO: Remove this, replace dtype using config
generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph)
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
torch.set_default_dtype(
torch.bfloat16
) # TODO: Remove this, replace dtype using config
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
)
if __name__ == "__main__":
fire.Fire(local_chat)
\ No newline at end of file
fire.Fire(local_chat)
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class LlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
Llama 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
and you expect the model to work with a longer `max_position_embeddings`, we recommend updating this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to the value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2.
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
```python
>>> from transformers import LlamaModel, LlamaConfig
>>> # Initializing a LLaMA llama-7b style configuration
>>> configuration = LlamaConfig()
>>> # Initializing a model from the llama-7b style configuration
>>> model = LlamaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
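# --- Illustrative sketch (added for this document, not part of the upstream file) ---
# A GQA-style configuration with 'linear' RoPE scaling, matching the argument docs
# above. The concrete sizes below are arbitrary example values (assumptions).
def _example_gqa_config():
    config = LlamaConfig(
        hidden_size=4096,
        num_attention_heads=32,
        num_key_value_heads=8,  # 8 KV heads shared by 32 query heads -> grouped-query attention
        max_position_embeddings=4096,
        rope_scaling={"rope_type": "linear", "factor": 2.0},  # checked by rope_config_validation
    )
    assert config.num_key_value_heads == 8
    return config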
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
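# Illustrative sketch (added for this document, not part of the upstream file):
# RMSNorm rescales by the root-mean-square of the features, with no mean subtraction;
# the freshly initialized weight of ones leaves that rescaling unchanged.
def _example_rmsnorm_check():
    x = torch.randn(2, 4, 16)
    norm = LlamaRMSNorm(16, eps=1e-6)
    manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(norm(x), manual, atol=1e-6)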
class LlamaRotaryEmbedding(nn.Module):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[LlamaConfig] = None,
):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.device = device
self.scaling_factor = scaling_factor
self.rope_type = rope_type
self.config = config
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.45"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get(
"rope_type", config.rope_scaling.get("type")
)
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
# seq_len = position_ids[0, -1] + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer(
"inv_freq", inv_freq, persistent=False
) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if (
seq_len < self.original_max_seq_len
and self.max_seq_len_cached > self.original_max_seq_len
): # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
# if "dynamic" in self.rope_type:
# self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
freqs = (
inv_freq_expanded.float() @ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
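# Illustrative sketch (added for this document, not part of the upstream file):
# the rotary embedding only reads dtype/device from `x` and returns cos/sin tables
# of shape (bsz, seq_len, head_dim), shared by all decoder layers.
def _example_rotary_shapes():
    cfg = LlamaConfig(hidden_size=256, num_attention_heads=4, num_hidden_layers=1)
    rope = LlamaRotaryEmbedding(config=cfg)
    x = torch.randn(1, 8, 256)
    position_ids = torch.arange(8).unsqueeze(0)
    cos, sin = rope(x, position_ids)
    assert cos.shape == sin.shape == (1, 8, 256 // 4)  # head_dim = 64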
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, *args, **kwargs):
logger.warning_once(
"`LlamaLinearScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use "
"`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
)
kwargs["rope_type"] = "linear"
super().__init__(*args, **kwargs)
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, *args, **kwargs):
logger.warning_once(
"`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use "
"`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
"__init__)."
)
kwargs["rope_type"] = "dynamic"
super().__init__(*args, **kwargs)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
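# Illustrative sketch (added for this document, not part of the upstream file):
# RoPE rotates each (x1, x2) pair of channels by a position-dependent angle, so
# per-token query/key norms are preserved (up to floating-point error).
def _example_apply_rope():
    cfg = LlamaConfig(hidden_size=128, num_attention_heads=2, num_hidden_layers=1)
    rope = LlamaRotaryEmbedding(config=cfg)
    q = torch.randn(1, 2, 8, 64)  # (bsz, num_heads, seq_len, head_dim)
    k = torch.randn(1, 2, 8, 64)
    position_ids = torch.arange(8).unsqueeze(0)
    cos, sin = rope(q, position_ids)  # (1, 8, 64) each
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)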
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
)
self.up_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
)
self.down_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=config.mlp_bias
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat(
[
F.linear(x, gate_proj_slices[i])
for i in range(self.config.pretraining_tp)
],
dim=-1,
)
up_proj = torch.cat(
[
F.linear(x, up_proj_slices[i])
for i in range(self.config.pretraining_tp)
],
dim=-1,
)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i])
for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
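# Illustrative sketch (added for this document, not part of the upstream file):
# with pretraining_tp == 1 the MLP is the plain SwiGLU form
# down_proj(silu(gate_proj(x)) * up_proj(x)).
def _example_mlp_swiglu():
    cfg = LlamaConfig(hidden_size=32, intermediate_size=64, num_attention_heads=4, num_hidden_layers=1)
    mlp = LlamaMLP(cfg)
    x = torch.randn(1, 3, 32)
    manual = mlp.down_proj(F.silu(mlp.gate_proj(x)) * mlp.up_proj(x))
    assert torch.allclose(mlp(x), manual)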
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
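# Illustrative sketch (added for this document, not part of the upstream file):
# repeat_kv matches torch.repeat_interleave along the head dimension, which is how
# each group of query heads ends up sharing one key/value head under GQA.
def _example_repeat_kv():
    kv = torch.randn(2, 8, 5, 64)  # (bsz, num_key_value_heads, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=4)
    assert expanded.shape == (2, 32, 5, 64)
    assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=4, dim=1))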
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.o_proj = nn.Linear(
self.hidden_size, self.hidden_size, bias=config.attention_bias
)
# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i])
for i in range(self.config.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i])
for i in range(self.config.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(
self.hidden_size // self.config.pretraining_tp, dim=2
)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.config.pretraining_tp, dim=1
)
attn_output = sum(
[
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)
]
)
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
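# Illustrative sketch (added for this document, not part of the upstream file):
# the eager path above computes softmax(Q @ K^T / sqrt(head_dim)) @ V; on unmasked
# inputs this matches torch's fused scaled_dot_product_attention.
def _example_eager_matches_sdpa():
    q, k, v = (torch.randn(1, 2, 5, 16) for _ in range(3))
    weights = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(16), dim=-1)
    eager = weights @ v
    fused = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(eager, fused, atol=1e-5)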
class LlamaFlashAttention2(LlamaAttention):
"""
Llama flash attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.45
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"The `static` cache implementation is not compatible with `attn_implementation==flash_attention_2`; "
"make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
)
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
# therefore we just need to keep the original shape
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, the layer norms are usually cast to float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# to fp32. (LlamaRMSNorm handles it correctly.)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seem to have been silently cast to float32; this might be related to"
f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back"
f" to {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaSdpaAttention(LlamaAttention):
"""
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`LlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to the
SDPA API.
"""
# Adapted from LlamaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
LLAMA_ATTENTION_CLASSES = {
"eager": LlamaAttention,
"flash_attention_2": LlamaFlashAttention2,
"sdpa": LlamaSdpaAttention,
}
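# Illustrative sketch (added for this document, not part of the upstream file):
# the decoder layer below picks its attention class from this mapping via
# `config._attn_implementation` (e.g. "eager", "sdpa", "flash_attention_2").
def _example_attention_dispatch():
    cfg = LlamaConfig(hidden_size=64, num_attention_heads=4, num_hidden_layers=1)
    attn = LLAMA_ATTENTION_CLASSES["sdpa"](config=cfg, layer_idx=0)
    assert isinstance(attn, LlamaSdpaAttention) and isinstance(attn, LlamaAttention)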
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx
)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
into the model.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
LLAMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlamaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not using_static_cache
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError(
"Custom 4D attention mask should be passed in inverted form with max==0"
)
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(
input_tensor.shape[0], 1, -1, -1
)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length]
+ attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
return causal_mask
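# Illustrative sketch (added for this document, not part of the upstream file):
# for a prefill step with no padding, the additive mask built in `_update_causal_mask`
# is 0 on and below the diagonal and the dtype's most negative value above it.
def _example_causal_mask():
    seq_len, dtype = 4, torch.float32
    min_dtype = torch.finfo(dtype).min
    cache_position = torch.arange(seq_len)
    mask = torch.full((seq_len, seq_len), fill_value=min_dtype, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    mask *= torch.arange(seq_len) > cache_position.reshape(-1, 1)
    assert mask[2, 1] == 0 and mask[2, 2] == 0  # token 2 attends to tokens 0..2
    assert mask[2, 3] == min_dtype              # but not to the future token 3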
class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(
self.vocab_size // self.config.pretraining_tp, dim=0
)
logits = [
F.linear(hidden_states, lm_head_slices[i])
for i in range(self.config.pretraining_tp)
]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
# logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif (
input_ids.shape[1] != cache_position.shape[0]
): # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {
"input_ids": input_ids.contiguous()
} # `contiguous()` needed for compilation use cases
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
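# Illustrative sketch (added for this document, not part of the upstream file):
# the causal-LM loss above shifts logits and labels so position t predicts token t+1;
# labels set to -100 are ignored by CrossEntropyLoss.
def _example_shifted_lm_loss():
    vocab_size = 11
    logits = torch.randn(1, 5, vocab_size)
    labels = torch.tensor([[7, 3, 9, -100, 2]])
    shift_logits = logits[..., :-1, :].reshape(-1, vocab_size)
    shift_labels = labels[..., 1:].reshape(-1)
    loss = CrossEntropyLoss()(shift_logits, shift_labels)  # averaged over the 3 non-ignored targets
    assert loss.ndim == 0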
@add_start_docstrings(
"""
The LLaMa Model transformer with a sequence classification head on top (linear layer).
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it needs to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = LlamaModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError(
"Cannot handle batch sizes > 1 if no padding token is defined."
)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
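                # e.g. for a row [tok, tok, pad, pad] the first pad sits at index 2, so the
                # last real token is at index 1; if no pad is present argmax returns 0 and the
                # modulo below wraps -1 to the final position of the row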
sequence_lengths = (
torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
)
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[
torch.arange(batch_size, device=logits.device), sequence_lengths
]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
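                # infer the loss type once: a single output means regression, integer labels
                # mean single-label classification, float targets mean multi-label classification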
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(
pooled_logits.view(-1, self.num_labels), labels.view(-1)
)
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
LLAMA_START_DOCSTRING,
)
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
base_model_prefix = "transformer"
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
def __init__(self, config):
super().__init__(config)
self.transformer = LlamaModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.transformer.embed_tokens
def set_input_embeddings(self, value):
self.transformer.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split may have added an extra trailing dimension;
            # squeeze it off and move the targets to the logits' device
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
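            # clamping out-of-range targets to ignored_index (== sequence length) lets
            # CrossEntropyLoss(ignore_index=ignored_index) skip those positions entirely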
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
LLAMA_START_DOCSTRING,
)
class LlamaForTokenClassification(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = LlamaModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
"""
Description :
Author : Boxin Zhang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
from torch import nn
from transformers import ROPE_INIT_FUNCTIONS
from ktransformers.models.modeling_llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaDynamicNTKScalingRotaryEmbedding,
)
from ktransformers.models.modeling_deepseek import (
DeepseekV2YarnRotaryEmbedding,
DeepseekV2RotaryEmbedding,
)
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
)
self.orig_module.__init__(
orig_module.dim, orig_module.max_position_embeddings, orig_module.base
)
self.generate_device = generate_device
self.prefill_device = prefill_device
def load(self):
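        # re-running the wrapped module's __init__ rebuilds its rotary buffers
        # (inv_freq and any cached cos/sin tables) on self.device at load time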
self.orig_module.__init__(
self.orig_module.dim,
self.orig_module.max_position_embeddings,
self.orig_module.base,
self.device,
)
class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
)
        self.orig_module.__init__(
            orig_module.dim,
            orig_module.max_position_embeddings,
            orig_module.base,
            None,
            orig_module.scaling_factor,
            orig_module.rope_type,
            orig_module.config,
        )
self.generate_device = generate_device
self.prefill_device = prefill_device
    def load(self):
        self.orig_module.__init__(
            self.orig_module.dim,
            self.orig_module.max_position_embeddings,
            self.orig_module.base,
            self.device,
            self.orig_module.scaling_factor,
            self.orig_module.rope_type,
            self.orig_module.config,
        )
class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
)
self.orig_module.__init__(
orig_module.dim,
orig_module.max_position_embeddings,
orig_module.base,
            None,  # device
orig_module.scaling_factor,
orig_module.original_max_position_embeddings,
orig_module.beta_fast,
orig_module.beta_slow,
orig_module.mscale,
            orig_module.mscale_all_dim,
        )
self.generate_device = generate_device
self.prefill_device = prefill_device
    def load(self):
        self.orig_module.__init__(
            self.orig_module.dim,
            self.orig_module.max_position_embeddings,
            self.orig_module.base,
            self.generate_device,
            self.orig_module.scaling_factor,
            self.orig_module.original_max_position_embeddings,
            self.orig_module.beta_fast,
            self.orig_module.beta_slow,
            self.orig_module.mscale,
            self.orig_module.mscale_all_dim,
        )
class DynamicNTKScalingRotaryEmbedding(
BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding
):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, device, **kwargs
)
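        # the wrapped rotary embedding is re-initialized here with device=None; the real
        # device is only bound later when load() re-runs __init__ below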
self.orig_module.__init__(
orig_module.dim,
orig_module.max_position_embeddings,
orig_module.base,
None, # device
orig_module.scaling_factor,
orig_module.rope_type,
orig_module.config,
)
def load(self):
self.orig_module.__init__(
self.orig_module.dim,
self.orig_module.max_position_embeddings,
self.orig_module.base,
self.orig_module.device,
self.orig_module.scaling_factor,
self.orig_module.rope_type,
self.orig_module.config,
)
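# Hedged illustration (not from the original file): every injected class above follows the
# same pattern, namely re-running the wrapped rotary embedding's __init__ with its stored
# hyper-parameters, first with device=None at injection time and again in load() with the
# real device. The stand-in below sketches only that re-initialization idea; TinyRotary is
# a hypothetical class and none of the BaseInjectedModule / GGUFLoader machinery is used.
if __name__ == "__main__":
    import torch

    class TinyRotary(nn.Module):
        def __init__(self, dim: int, base: float = 10000.0, device=None):
            super().__init__()
            # classic RoPE inverse-frequency table, built directly on the requested device
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

    rope = TinyRotary(dim=128)  # first pass: device=None, buffers live on the CPU
    target = "cuda:0" if torch.cuda.is_available() else "cpu"
    rope.__init__(dim=128, device=target)  # "load" step: rebuild the buffers on the target device
    print("inv_freq device:", rope.inv_freq.device)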