Unverified Commit c68f367e authored by Daniel Hiltgen's avatar Daniel Hiltgen Committed by GitHub
Browse files

Update GGML to b6646 (#12245)

Notable EOLs with this change:
- MacOS v12 and v13 are no longer supported (v14+ required)
- AMD gfx900 and gfx906 are no longer supported
parent fdb10946
......@@ -16,10 +16,10 @@
static std::string trim(const std::string & str) {
size_t start = 0;
size_t end = str.size();
while (start < end && isspace(str[start])) {
while (start < end && isspace(static_cast<unsigned char>(str[start]))) {
start += 1;
}
while (end > start && isspace(str[end - 1])) {
while (end > start && isspace(static_cast<unsigned char>(str[end - 1]))) {
end -= 1;
}
return str.substr(start, end - start);
......@@ -69,6 +69,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
......@@ -201,6 +203,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
} else if (tmpl_contains("<seed:bos>")) {
return LLM_CHAT_TEMPLATE_SEED_OSS;
} else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
return LLM_CHAT_TEMPLATE_GROK_2;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
......@@ -752,6 +758,28 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|im_assistant|>assistant<|im_middle|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) {
for (auto message: chat) {
std::string role(message->role);
ss << "<seed:bos>" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << "<seed:eos>";
}
if (add_ass) {
ss << "<seed:bos>assistant\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "System: " << trim(message->content) << "<|separator|>\n\n";
} else if (role == "user") {
ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
} else if (role == "assistant") {
ss << "Assistant: " << message->content << "<|separator|>\n\n";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;
......
......@@ -49,6 +49,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_OPENAI_MOE,
LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_SEED_OSS,
LLM_CHAT_TEMPLATE_GROK_2,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
......
This diff is collapsed.
......@@ -17,9 +17,17 @@ class llama_batch_allocr;
class llama_io_read_i;
class llama_io_write_i;
// "memory" as in abstract memory for the context
struct llama_memory_i;
struct llama_memory_context_i;
// "memory" as in physical memory for a buffer type, in bytes
struct llama_memory_breakdown_data {
size_t model = 0; // memory allocated for the model
size_t context = 0; // memory allocated for the context
size_t compute = 0; // memory allocated for temporary compute buffers
};
struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs
llama_context(
......@@ -46,10 +54,8 @@ struct llama_context {
llama_memory_t get_memory() const;
// return true of the KV cache was updated
// TODO: remove
bool kv_self_update(bool optimize);
void kv_self_defrag_sched();
// return true if the memory was updated
bool memory_update(bool optimize);
enum llama_pooling_type pooling_type() const;
......@@ -111,9 +117,9 @@ struct llama_context {
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
bool state_load_file(
const char * filepath,
......@@ -146,12 +152,15 @@ struct llama_context {
llama_perf_context_data perf_get_data() const;
void perf_reset();
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown() const;
//
// training
//
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
// TODO: more flexible combinations of logical/physical batch size and context size
void opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
......@@ -197,7 +206,7 @@ public:
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
private:
llm_graph_params graph_params(
......@@ -212,8 +221,8 @@ private:
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
//
// members
......@@ -229,9 +238,6 @@ private:
std::unique_ptr<llama_memory_i> memory;
// TODO: temporary, until the llama_kv_self_defrag() API is removed
bool memory_force_optimize = false;
// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;
......@@ -287,10 +293,6 @@ private:
bool has_evaluated_once = false;
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
bool supports_set_rows = true;
// env: LLAMA_GRAPH_REUSE_DISABLE
bool graph_reuse_disable = false;
......
......@@ -4,7 +4,7 @@
#include <cstdint>
#define LLAMA_MAX_SEQ 64
#define LLAMA_MAX_SEQ 256
struct llama_cparams {
uint32_t n_ctx; // context size used during inference
......@@ -24,7 +24,6 @@ struct llama_cparams {
float yarn_attn_factor;
float yarn_beta_fast;
float yarn_beta_slow;
float defrag_thold;
bool embeddings;
bool causal_attn;
......
This diff is collapsed.
......@@ -19,8 +19,8 @@ struct llama_cparams;
struct llama_memory_context_i;
class llama_kv_cache_unified_context;
class llama_kv_cache_unified_iswa_context;
class llama_kv_cache_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
......@@ -78,6 +78,11 @@ struct llm_graph_params;
class llm_graph_input_i {
public:
llm_graph_input_i() {
const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
}
virtual ~llm_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
......@@ -90,6 +95,9 @@ public:
GGML_UNUSED(params);
return false;
}
protected:
// env: LLAMA_GRAPH_INPUT_DEBUG
int debug = 0;
};
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
......@@ -152,7 +160,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
......@@ -161,7 +169,7 @@ public:
const llama_hparams hparams;
const llama_kv_cache_unified_context * mctx;
const llama_kv_cache_context * mctx;
};
class llm_graph_input_out_ids : public llm_graph_input_i {
......@@ -198,7 +206,7 @@ public:
class llm_graph_input_cls : public llm_graph_input_i {
public:
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
virtual ~llm_graph_input_cls() = default;
void set_input(const llama_ubatch * ubatch) override;
......@@ -206,6 +214,7 @@ public:
ggml_tensor * cls; // I32 [n_batch]
const llama_cparams cparams;
const llm_arch arch;
};
class llm_graph_input_rs : public llm_graph_input_i {
......@@ -257,17 +266,17 @@ public:
const llama_cparams cparams;
};
class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
class llm_graph_input_attn_kv : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified(
llm_graph_input_attn_kv(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_context * mctx) :
const llama_kv_cache_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified() = default;
~llm_graph_input_attn_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
......@@ -290,20 +299,20 @@ public:
const llama_hparams hparams;
const llama_cparams cparams;
const llama_kv_cache_unified_context * mctx;
const llama_kv_cache_context * mctx;
};
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified_iswa(
llm_graph_input_attn_kv_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_iswa_context * mctx) :
const llama_kv_cache_iswa_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified_iswa() = default;
~llm_graph_input_attn_kv_iswa() = default;
void set_input(const llama_ubatch * ubatch) override;
......@@ -330,7 +339,7 @@ public:
const llama_hparams hparams;
const llama_cparams cparams;
const llama_kv_cache_unified_iswa_context * mctx;
const llama_kv_cache_iswa_context * mctx;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
......@@ -351,7 +360,7 @@ public:
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) :
inp_attn(std::move(inp_attn)),
......@@ -361,11 +370,11 @@ public:
void set_input(const llama_ubatch * ubatch) override;
std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
const llama_memory_hybrid_context * mctx;
};
......@@ -680,14 +689,15 @@ struct llm_graph_context {
//
ggml_tensor * build_attn_mha(
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * sinks,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale) const;
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
......@@ -699,50 +709,39 @@ struct llm_graph_context {
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
llm_graph_input_attn_kv * build_attn_inp_kv() const;
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
llm_graph_input_attn_kv * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
// TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
ggml_tensor * build_attn_with_sinks(
llm_graph_input_attn_kv_unified_iswa * inp,
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
......@@ -756,6 +755,7 @@ struct llm_graph_context {
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
......@@ -765,7 +765,7 @@ struct llm_graph_context {
//
// TODO: move this implementation to llama_memory_recurrent.
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
// this is analogous to llama_kv_cache::cpy_k / cpy_v
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_memory_recurrent`
......
#include "llama-hparams.h"
#include "ggml.h"
#include <cassert>
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
if (dense_first) {
......@@ -161,3 +162,64 @@ bool llama_hparams::is_swa(uint32_t il) const {
GGML_ABORT("fatal error");
}
bool llama_hparams::has_kv(uint32_t il) const {
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
return true;
}
return false;
}
// by default, all layers have kv
return true;
}
uint32_t llama_hparams::n_layer_kv() const {
uint32_t res = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (has_kv(il)) {
res++;
}
}
return res;
}
bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
{
} break;
case LLAMA_SWA_TYPE_STANDARD:
{
if (p1 - p0 >= (int32_t) n_swa) {
return true;
}
} break;
case LLAMA_SWA_TYPE_CHUNKED:
{
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
if (p0 < pos_chunk_start) {
return true;
}
} break;
case LLAMA_SWA_TYPE_SYMMETRIC:
{
const int32_t half_n_swa = (int32_t) n_swa / 2;
const int32_t pos_diff = p1 - p0;
// Mask if outside the symmetric window
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
return true;
}
} break;
}
return false;
}
......@@ -16,9 +16,10 @@ enum llama_expert_gating_func_type {
};
enum llama_swa_type {
LLAMA_SWA_TYPE_NONE = 0,
LLAMA_SWA_TYPE_STANDARD = 1,
LLAMA_SWA_TYPE_CHUNKED = 2,
LLAMA_SWA_TYPE_NONE = 0,
LLAMA_SWA_TYPE_STANDARD = 1,
LLAMA_SWA_TYPE_CHUNKED = 2,
LLAMA_SWA_TYPE_SYMMETRIC = 3,
};
struct llama_hparams_posnet {
......@@ -41,6 +42,7 @@ struct llama_hparams {
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
......@@ -69,10 +71,13 @@ struct llama_hparams {
uint32_t n_lora_kv = 0;
uint32_t n_ff_exp = 0;
uint32_t n_ff_shexp = 0;
uint32_t n_ff_chexp = 0;
uint32_t n_expert_shared = 0;
uint32_t n_norm_groups = 0;
uint32_t n_group_experts = 0;
float expert_weights_scale = 0.0;
float expert_group_scale = 0.05f;
float expert_weights_scale = 0.0f;
bool expert_weights_norm = false;
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
uint32_t moe_every_n_layers = 0;
......@@ -82,8 +87,9 @@ struct llama_hparams {
float f_norm_rms_eps;
float f_norm_group_eps;
float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;
float f_attn_logit_softcapping = 50.0f;
float f_router_logit_softcapping = 30.0f;
float f_final_logit_softcapping = 30.0f;
// for RWKV
uint32_t rescale_every_n_layers = 0;
......@@ -104,6 +110,11 @@ struct llama_hparams {
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul = 0.0f;
float yarn_ext_factor = -1.0f;
float yarn_attn_factor = 1.0f;
float yarn_beta_fast = 32.0f;
float yarn_beta_slow = 1.0f;
std::array<int, 4> rope_sections;
// Sliding Window Attention (SWA)
......@@ -136,10 +147,14 @@ struct llama_hparams {
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
// grok-2
float f_attn_out_scale = 0.0f;
uint32_t attn_temp_length = 0;
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;
bool use_kq_norm = true;
bool use_kq_norm = false;
// for Classifiers
uint32_t n_cls_out = 1;
......@@ -159,6 +174,7 @@ struct llama_hparams {
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
uint32_t dec_n_layer = 0;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
......@@ -226,6 +242,16 @@ struct llama_hparams {
bool n_bskcn(uint32_t n, uint32_t il) const;
bool is_swa(uint32_t il) const;
bool has_kv(uint32_t il) const;
// number of layers for which has_kv() returns true
uint32_t n_layer_kv() const;
// note that this function uses different SWA parameters from those in the hparams
// TODO: think of a better place for this function
// TODO: pack the SWA params in a struct?
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
};
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
......
......@@ -59,3 +59,5 @@ std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
std::string llama_format_tensor_shape(const struct ggml_tensor * t);
std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-iswa.h"
#include "llama-impl.h"
#include "llama-batch.h"
......@@ -8,10 +8,10 @@
#include <cassert>
//
// llama_kv_cache_unified_iswa
// llama_kv_cache_iswa
//
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
llama_kv_cache_iswa::llama_kv_cache_iswa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
......@@ -22,9 +22,26 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}
return !model.hparams.is_swa(il);
};
const layer_filter_cb filter_swa = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}
return model.hparams.is_swa(il);
};
const uint32_t size_base = kv_size;
......@@ -40,25 +57,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v,
kv_base = std::make_unique<llama_kv_cache>(
model, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v,
kv_swa = std::make_unique<llama_kv_cache>(
model, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
}
void llama_kv_cache_unified_iswa::clear(bool data) {
void llama_kv_cache_iswa::clear(bool data) {
kv_base->clear(data);
kv_swa ->clear(data);
}
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
bool res = true;
res = res & kv_base->seq_rm(seq_id, p0, p1);
......@@ -67,36 +84,44 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam
return res;
}
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
kv_base->seq_keep(seq_id);
kv_swa ->seq_keep(seq_id);
}
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
kv_base->seq_add(seq_id, p0, p1, shift);
kv_swa ->seq_add(seq_id, p0, p1, shift);
}
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
kv_base->seq_div(seq_id, p0, p1, d);
kv_swa ->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
return kv_swa->seq_pos_min(seq_id);
}
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown();
for (const auto & buft_size : kv_swa->memory_breakdown()) {
mb[buft_size.first] += buft_size.second;
}
return mb;
}
llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);
// first try simple split
......@@ -136,7 +161,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
return std::make_unique<llama_kv_cache_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
......@@ -172,61 +197,67 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
return std::make_unique<llama_kv_cache_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
// TODO: if we fail again, we should attempt different splitting strategies
// but to do that properly, we first have to refactor the batches to be more flexible
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
return std::make_unique<llama_kv_cache_iswa_context>(this);
}
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
}
bool llama_kv_cache_unified_iswa::get_can_shift() const {
bool llama_kv_cache_iswa::get_can_shift() const {
return kv_base->get_size() == kv_swa->get_size();
}
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
kv_base->state_write(io, seq_id);
kv_swa ->state_write(io, seq_id);
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}
kv_swa->state_write(io, seq_id, flags);
}
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
kv_base->state_read(io, seq_id);
kv_swa ->state_read(io, seq_id);
void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}
kv_swa->state_read(io, seq_id, flags);
}
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
llama_kv_cache * llama_kv_cache_iswa::get_base() const {
return kv_base.get();
}
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
return kv_swa.get();
}
//
// llama_kv_cache_unified_iswa_context
// llama_kv_cache_iswa_context
//
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv) :
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv) :
ctx_base(kv->get_base()->init_full()),
ctx_swa (kv->get_swa ()->init_full()),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv,
llama_context * lctx,
bool optimize) :
ctx_base(kv->get_base()->init_update(lctx, optimize)),
......@@ -234,21 +265,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
bool llama_kv_cache_unified_iswa_context::next() {
bool llama_kv_cache_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
ctx_base->next();
......@@ -261,7 +292,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
return true;
}
bool llama_kv_cache_unified_iswa_context::apply() {
bool llama_kv_cache_iswa_context::apply() {
assert(!llama_memory_status_is_fail(status));
bool res = true;
......@@ -272,24 +303,24 @@ bool llama_kv_cache_unified_iswa_context::apply() {
return res;
}
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
llama_memory_status llama_kv_cache_iswa_context::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
return static_cast<const llama_kv_cache_context *>(ctx_base.get());
}
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
}
#pragma once
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache.h"
#include <vector>
//
// llama_kv_cache_unified_iswa
// llama_kv_cache_iswa
//
// utilizes two instances of llama_kv_cache_unified
// utilizes two instances of llama_kv_cache
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
class llama_kv_cache_unified_iswa : public llama_memory_i {
class llama_kv_cache_iswa : public llama_memory_i {
public:
llama_kv_cache_unified_iswa(
llama_kv_cache_iswa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
......@@ -24,9 +24,11 @@ public:
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad);
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache_unified_iswa() = default;
~llama_kv_cache_iswa() = default;
//
// llama_memory_i
......@@ -54,52 +56,54 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_unified_iswa specific API
// llama_kv_cache_iswa specific API
//
llama_kv_cache_unified * get_base() const;
llama_kv_cache_unified * get_swa () const;
llama_kv_cache * get_base() const;
llama_kv_cache * get_swa () const;
private:
const llama_hparams & hparams;
const bool unified;
std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa;
std::unique_ptr<llama_kv_cache> kv_base;
std::unique_ptr<llama_kv_cache> kv_swa;
};
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
class llama_kv_cache_iswa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
// used for errors
llama_kv_cache_unified_iswa_context(llama_memory_status status);
llama_kv_cache_iswa_context(llama_memory_status status);
// used to create a full-cache context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv);
llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv);
// used to create an update context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv,
llama_context * lctx,
bool optimize);
// used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_context();
virtual ~llama_kv_cache_iswa_context();
//
// llama_memory_context_i
......@@ -112,14 +116,14 @@ public:
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_iswa_context specific API
// llama_kv_cache_iswa_context specific API
//
const llama_kv_cache_unified_context * get_base() const;
const llama_kv_cache_unified_context * get_swa() const;
const llama_kv_cache_context * get_base() const;
const llama_kv_cache_context * get_swa() const;
private:
//llama_kv_cache_unified_iswa * kv;
//llama_kv_cache_iswa * kv;
// the index of the next ubatch to process
size_t i_next = 0;
......
......@@ -14,27 +14,13 @@ struct llama_model;
struct llama_context;
//
// llama_kv_cache_unified
// llama_kv_cache
//
class llama_kv_cache_unified : public llama_memory_i {
class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
struct defrag_info {
bool empty() const {
return ids.empty();
}
// contains information about which cell moves where:
// - cell i moves to ids[i]
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
std::vector<uint32_t> ids;
};
struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
......@@ -52,8 +38,8 @@ public:
using idx_vec_t = std::vector<uint32_t>;
// number of streams: ns = s1 - s0 + 1
llama_seq_id s0;
llama_seq_id s1;
uint32_t s0;
uint32_t s1;
std::vector<llama_seq_id> strm; // [ns]
std::vector<idx_vec_t> idxs; // [ns]
......@@ -92,21 +78,22 @@ public:
using slot_info_vec_t = std::vector<slot_info>;
llama_kv_cache_unified(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type);
~llama_kv_cache_unified() = default;
llama_kv_cache(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache() = default;
//
// llama_memory_i
......@@ -134,13 +121,15 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_unified specific API
// llama_kv_cache specific API
//
uint32_t get_size() const;
......@@ -152,10 +141,7 @@ public:
// graph_build API
//
uint32_t get_n_kv() const;
// TODO: temporary
bool get_supports_set_rows() const;
uint32_t get_n_kv(const slot_info & sinfo) const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
......@@ -173,7 +159,7 @@ public:
// return empty vector on failure
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
// find a slot of kv cells that can hold the ubatch
// if cont == true, then the slot must be continuous
......@@ -228,10 +214,7 @@ private:
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
bool supports_set_rows = true;
// this is the SWA type of the cache - not to be confused with the model SWA type
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs;
......@@ -241,7 +224,7 @@ private:
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells_unified> v_cells;
std::vector<llama_kv_cells> v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
......@@ -254,9 +237,6 @@ private:
// model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids;
// return non-empty vector if cells have been moved
defrag_info defrag_prepare(int32_t n_max_nodes) const;
size_t total_size() const;
size_t size_k_bytes() const;
......@@ -277,11 +257,6 @@ private:
llm_graph_result * res,
llama_context * lctx) const;
ggml_cgraph * build_graph_defrag(
llm_graph_result * res,
llama_context * lctx,
const defrag_info & dinfo) const;
struct cell_ranges_t {
uint32_t strm;
......@@ -295,35 +270,33 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};
class llama_kv_cache_unified_context : public llama_memory_context_i {
class llama_kv_cache_context : public llama_memory_context_i {
public:
// some shorthands
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info;
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
using stream_copy_info = llama_kv_cache::stream_copy_info;
// used for errors
llama_kv_cache_unified_context(llama_memory_status status);
llama_kv_cache_context(llama_memory_status status);
// used to create a full-cache context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv);
llama_kv_cache_context(
llama_kv_cache * kv);
// used to create an update context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_context(
llama_kv_cache * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo,
stream_copy_info sc_info);
// used to create a batch procesing context from a batch
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_context(
llama_kv_cache * kv,
slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_context();
virtual ~llama_kv_cache_context();
//
// llama_memory_context_i
......@@ -336,22 +309,27 @@ public:
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_context specific API
// llama_kv_cache_context specific API
//
uint32_t get_n_kv() const;
// TODO: temporary
bool get_supports_set_rows() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location
// note: the heads in k_cur and v_cur should be layed out contiguously in memory
// - k_cur [n_embd_head_k, n_head_k, n_tokens]
// - k_idxs [n_tokens]
// - v_cur [n_embd_head_v, n_head_v, n_tokens]
// - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
// create destination indices for each head of the current batch for where it would be written in the KV cache
// the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
// helps understand the implementation logic of cpy_k and cpy_v
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
......@@ -365,7 +343,7 @@ public:
private:
llama_memory_status status;
llama_kv_cache_unified * kv;
llama_kv_cache * kv;
llama_context * lctx;
//
......@@ -374,8 +352,6 @@ private:
bool do_shift = false;
defrag_info dinfo;
stream_copy_info sc_info;
//
......
......@@ -11,7 +11,7 @@
// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
class llama_kv_cells_unified {
class llama_kv_cells {
public:
void reset() {
for (uint32_t i = 0; i < pos.size(); ++i) {
......@@ -77,30 +77,30 @@ public:
}
// move cell isrc to idst (used during defrag)
void mv(uint32_t isrc, uint32_t idst) {
assert(isrc < pos.size());
assert(idst < pos.size());
//void mv(uint32_t isrc, uint32_t idst) {
// assert(isrc < pos.size());
// assert(idst < pos.size());
assert(pos[idst] == -1);
assert(pos[isrc] != -1);
// assert(pos[idst] == -1);
// assert(pos[isrc] != -1);
pos [idst] = pos [isrc];
shift[idst] = shift[isrc];
seq [idst] = seq [isrc];
// pos [idst] = pos [isrc];
// shift[idst] = shift[isrc];
// seq [idst] = seq [isrc];
pos [isrc] = -1;
shift[isrc] = 0;
seq [isrc].reset();
// pos [isrc] = -1;
// shift[isrc] = 0;
// seq [isrc].reset();
used.erase (isrc);
used.insert(idst);
}
// used.erase (isrc);
// used.insert(idst);
//}
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
llama_kv_cells cp(uint32_t i, uint32_t n) const {
assert(i + n <= pos.size());
llama_kv_cells_unified res;
llama_kv_cells res;
res.resize(n);
......@@ -117,8 +117,8 @@ public:
}
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells_unified res;
llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells res;
res.resize(idxs.size());
......@@ -135,7 +135,7 @@ public:
}
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
void set(uint32_t i, const llama_kv_cells_unified & other) {
void set(uint32_t i, const llama_kv_cells & other) {
assert(i + other.pos.size() <= pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
......@@ -165,7 +165,7 @@ public:
}
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
assert(idxs.size() == other.pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
......
......@@ -9,32 +9,29 @@
//
llama_memory_hybrid::llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn,
layer_filter_cb && filter_recr) :
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn,
const layer_filter_cb & filter_recr) :
hparams(model.hparams),
mem_attn(new llama_kv_cache_unified(
mem_attn(new llama_kv_cache(
model,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
type_k,
type_v,
v_trans,
......@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max,
n_pad,
n_swa,
swa_type
swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
)),
mem_recr(new llama_memory_recurrent(
model,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr,
type_r,
type_s,
offload,
rs_size,
n_seq_max
n_seq_max,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
)) {}
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
......@@ -165,17 +166,29 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
for (const auto & buft_size : mem_recr->memory_breakdown()) {
mb[buft_size.first] += buft_size.second;
}
return mb;
}
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
}
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);
mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
}
llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
return mem_attn.get();
}
......@@ -206,7 +219,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
......@@ -244,8 +257,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
return ubatches[i_next];
}
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
}
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
......
......@@ -2,7 +2,7 @@
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"
......@@ -13,36 +13,32 @@
// llama_memory_hybrid
//
// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
// utilizes instances of llama_memory_recurrent and llama_kv_cache to
// support models where each layer may be either attention-based or recurrent
class llama_memory_hybrid : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn = nullptr,
layer_filter_cb && filter_recr = nullptr);
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn = nullptr,
const layer_filter_cb & filter_recr = nullptr);
~llama_memory_hybrid() = default;
......@@ -72,28 +68,30 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_memory_hybrid specific API
//
llama_kv_cache_unified * get_mem_attn() const;
llama_kv_cache * get_mem_attn() const;
llama_memory_recurrent * get_mem_recr() const;
private:
const llama_hparams & hparams;
const std::unique_ptr<llama_kv_cache_unified> mem_attn;
const std::unique_ptr<llama_kv_cache> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
class llama_memory_hybrid_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
// init failure
explicit llama_memory_hybrid_context(llama_memory_status status);
......@@ -125,7 +123,7 @@ public:
// llama_memory_hybrid_context
//
const llama_kv_cache_unified_context * get_attn() const;
const llama_kv_cache_context * get_attn() const;
const llama_memory_recurrent_context * get_recr() const;
private:
......
......@@ -16,13 +16,13 @@
//
llama_memory_recurrent::llama_memory_recurrent(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const llama_model & model,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;
head = 0;
......@@ -359,6 +359,14 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
}
return ret;
}
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do {
balloc.split_reset();
......@@ -680,7 +688,9 @@ size_t llama_memory_recurrent::size_s_bytes() const {
return size_s_bytes;
}
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
......@@ -718,7 +728,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
state_write_data(io, cell_ranges);
}
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);
uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));
......
......@@ -4,6 +4,7 @@
#include "llama-graph.h"
#include "llama-memory.h"
#include <map>
#include <set>
#include <vector>
......@@ -12,21 +13,17 @@
//
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
// see the implementation of llama_kv_cache_context_i for an example how to do it
class llama_memory_recurrent : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_recurrent(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max);
const llama_model & model,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
const layer_filter_cb & filter);
~llama_memory_recurrent() = default;
......@@ -54,6 +51,8 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
bool prepare(const std::vector<llama_ubatch> & ubatches);
// find a contiguous slot of memory cells and emplace the ubatch there
......@@ -63,8 +62,8 @@ public:
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
......
......@@ -2,7 +2,9 @@
#include "llama.h"
#include <map>
#include <memory>
#include <functional>
struct llama_ubatch;
......@@ -36,8 +38,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);
// the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see:
// - llama_kv_cache_unified_context
// - llama_kv_cache_unified_iswa_context
// - llama_kv_cache_context
// - llama_kv_cache_iswa_context
// ...
//
// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
......@@ -64,6 +66,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
struct llama_memory_i {
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
// this callback is used to specify which layers should reuse memory from other layers
// return negative value to indicate that the layer il should not reuse memory
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
......@@ -77,7 +86,7 @@ struct llama_memory_i {
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_context_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc.
// prepare for any pending memory updates, such as shifts, copies, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
......@@ -100,17 +109,14 @@ struct llama_memory_i {
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
virtual std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const = 0;
//
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
// TODO: temporary until the llama_kv_cache is removed from the public API
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment