Unverified Commit c68f367e authored by Daniel Hiltgen, committed by GitHub

Update GGML to b6646 (#12245)

Notable EOLs with this change:
- macOS v12 and v13 are no longer supported (v14+ required)
- AMD gfx900 and gfx906 are no longer supported
parent fdb10946
@@ -16,10 +16,10 @@
 static std::string trim(const std::string & str) {
     size_t start = 0;
     size_t end = str.size();
-    while (start < end && isspace(str[start])) {
+    while (start < end && isspace(static_cast<unsigned char>(str[start]))) {
         start += 1;
     }
-    while (end > start && isspace(str[end - 1])) {
+    while (end > start && isspace(static_cast<unsigned char>(str[end - 1]))) {
         end -= 1;
     }
     return str.substr(start, end - start);
@@ -69,6 +69,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
     { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
     { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
+    { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
+    { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -201,6 +203,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
     } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;
+    } else if (tmpl_contains("<seed:bos>")) {
+        return LLM_CHAT_TEMPLATE_SEED_OSS;
+    } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
+        return LLM_CHAT_TEMPLATE_GROK_2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -752,6 +758,28 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|im_assistant|>assistant<|im_middle|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<seed:bos>" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << "<seed:eos>";
+        }
+        if (add_ass) {
+            ss << "<seed:bos>assistant\n";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "System: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "user") {
+                ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "assistant") {
+                ss << "Assistant: " << message->content << "<|separator|>\n\n";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
......
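For context on the two new chat templates, below is a small, self-contained sketch of the prompt string the Grok-2 branch above builds for a short conversation (system and user content are trimmed in the real code). The message list and driver loop are illustrative only and simply mirror the formatting logic in the diff; they are not part of the commit.

    #include <iostream>
    #include <string>
    #include <vector>

    // Illustrative only: mirrors the LLM_CHAT_TEMPLATE_GROK_2 branch above to show
    // the prompt layout it produces; it does not call into llama.cpp itself.
    int main() {
        struct msg { std::string role, content; };
        const std::vector<msg> chat = {
            { "system",    "You are a helpful assistant." },
            { "user",      "Hello!" },
            { "assistant", "Hi there." },
            { "user",      "What is 2 + 2?" },
        };

        std::string ss;
        for (const auto & m : chat) {
            if (m.role == "system") {
                ss += "System: "    + m.content + "<|separator|>\n\n";
            } else if (m.role == "user") {
                ss += "Human: "     + m.content + "<|separator|>\n\n";
            } else if (m.role == "assistant") {
                ss += "Assistant: " + m.content + "<|separator|>\n\n";
            }
        }
        ss += "Assistant:"; // add_ass == true: cue the model for the next turn

        std::cout << ss << std::endl;
        return 0;
    }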
@@ -49,6 +49,8 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_OPENAI_MOE,
     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
+    LLM_CHAT_TEMPLATE_SEED_OSS,
+    LLM_CHAT_TEMPLATE_GROK_2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
......
This diff is collapsed.
@@ -17,9 +17,17 @@ class llama_batch_allocr;
 class llama_io_read_i;
 class llama_io_write_i;
 
+// "memory" as in abstract memory for the context
 struct llama_memory_i;
 struct llama_memory_context_i;
 
+// "memory" as in physical memory for a buffer type, in bytes
+struct llama_memory_breakdown_data {
+    size_t model   = 0; // memory allocated for the model
+    size_t context = 0; // memory allocated for the context
+    size_t compute = 0; // memory allocated for temporary compute buffers
+};
+
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
     llama_context(
@@ -46,10 +54,8 @@ struct llama_context {
     llama_memory_t get_memory() const;
 
-    // return true of the KV cache was updated
-    // TODO: remove
-    bool kv_self_update(bool optimize);
-    void kv_self_defrag_sched();
+    // return true if the memory was updated
+    bool memory_update(bool optimize);
 
     enum llama_pooling_type pooling_type() const;
@@ -111,9 +117,9 @@ struct llama_context {
     size_t state_get_data(      uint8_t * dst, size_t size);
     size_t state_set_data(const uint8_t * src, size_t size);
 
-    size_t state_seq_get_size(llama_seq_id seq_id);
-    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size);
-    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
+    size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
+    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size, llama_state_seq_flags flags);
+    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
 
     bool state_load_file(
             const char * filepath,
@@ -146,12 +152,15 @@ struct llama_context {
     llama_perf_context_data perf_get_data() const;
     void perf_reset();
 
+    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown() const;
+
     //
     // training
     //
 
     void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
 
+    // TODO: more flexible combinations of logical/physical batch size and context size
     void opt_epoch(
             ggml_opt_dataset_t dataset,
             ggml_opt_result_t result_train,
@@ -197,7 +206,7 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
 
 private:
     llm_graph_params graph_params(
@@ -212,8 +221,8 @@ private:
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i  & io);
 
-    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
-    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id);
+    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
+    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id, llama_state_seq_flags flags);
 
     //
     // members
@@ -229,9 +238,6 @@ private:
     std::unique_ptr<llama_memory_i> memory;
 
-    // TODO: temporary, until the llama_kv_self_defrag() API is removed
-    bool memory_force_optimize = false;
-
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;
@@ -287,10 +293,6 @@ private:
     bool has_evaluated_once = false;
 
-    // env: LLAMA_SET_ROWS (temporary)
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
-    bool supports_set_rows = true;
-
     // env: LLAMA_GRAPH_REUSE_DISABLE
     bool graph_reuse_disable = false;
......
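The new memory_breakdown() accessor on llama_context returns per-buffer-type totals (model, context, compute, in bytes). A minimal sketch of how a caller with access to the internal headers might log them, assuming ggml_backend_buft_name() for the buffer-type label; the helper itself is hypothetical and not part of this commit:

    #include <cstdio>

    #include "ggml-backend.h"
    #include "llama-context.h" // internal header; assumed available to the caller

    // Hypothetical logging helper: walks the map returned by the new
    // llama_context::memory_breakdown() and prints the per-buffer-type totals.
    static void print_memory_breakdown(const llama_context & ctx) {
        for (const auto & [buft, mb] : ctx.memory_breakdown()) {
            std::printf("%s: model = %zu B, context = %zu B, compute = %zu B\n",
                    ggml_backend_buft_name(buft), mb.model, mb.context, mb.compute);
        }
    }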
@@ -4,7 +4,7 @@
 #include <cstdint>
 
-#define LLAMA_MAX_SEQ 64
+#define LLAMA_MAX_SEQ 256
 
 struct llama_cparams {
     uint32_t n_ctx; // context size used during inference
@@ -24,7 +24,6 @@ struct llama_cparams {
     float yarn_attn_factor;
     float yarn_beta_fast;
     float yarn_beta_slow;
-    float defrag_thold;
 
     bool embeddings;
     bool causal_attn;
......
This diff is collapsed.
@@ -19,8 +19,8 @@ struct llama_cparams;
 struct llama_memory_context_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
+class llama_kv_cache_context;
+class llama_kv_cache_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;
@@ -78,6 +78,11 @@ struct llm_graph_params;
 class llm_graph_input_i {
 public:
+    llm_graph_input_i() {
+        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
+        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
+    }
+
     virtual ~llm_graph_input_i() = default;
 
     virtual void set_input(const llama_ubatch * ubatch) = 0;
@@ -90,6 +95,9 @@ public:
         GGML_UNUSED(params);
         return false;
     }
+
+protected:
+    // env: LLAMA_GRAPH_INPUT_DEBUG
+    int debug = 0;
 };
 
 using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
@@ -152,7 +160,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +169,7 @@ public:
     const llama_hparams hparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -198,7 +206,7 @@ public:
 class llm_graph_input_cls : public llm_graph_input_i {
 public:
-    llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
+    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
     virtual ~llm_graph_input_cls() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -206,6 +214,7 @@ public:
     ggml_tensor * cls; // I32 [n_batch]
 
     const llama_cparams cparams;
+    const llm_arch arch;
 };
 
 class llm_graph_input_rs : public llm_graph_input_i {
@@ -257,17 +266,17 @@ public:
     const llama_cparams cparams;
 };
 
-class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+class llm_graph_input_attn_kv : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified(
+    llm_graph_input_attn_kv(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    ~llm_graph_input_attn_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -290,20 +299,20 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
-class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified_iswa(
+    llm_graph_input_attn_kv_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_context * mctx) :
+            const llama_kv_cache_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified_iswa() = default;
+    ~llm_graph_input_attn_kv_iswa() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -330,7 +339,7 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +360,7 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
-            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
             std::unique_ptr<llm_graph_input_rs> inp_rs,
             const llama_memory_hybrid_context * mctx) :
         inp_attn(std::move(inp_attn)),
@@ -361,11 +370,11 @@ public:
     void set_input(const llama_ubatch * ubatch) override;
 
-    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
     std::unique_ptr<llm_graph_input_rs> inp_rs;
 
-    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
     llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
 
     const llama_memory_hybrid_context * mctx;
 };
@@ -680,14 +689,15 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
             ggml_tensor * q,      // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k,      // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v,      // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
             ggml_tensor * kq_b,
             ggml_tensor * kq_mask,
-            ggml_tensor * sinks,
+            ggml_tensor * sinks,  // [n_head_q]
             ggml_tensor * v_mla,  // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale) const;
+            float kq_scale,
+            int il) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@@ -699,50 +709,39 @@ struct llm_graph_context {
             ggml_tensor * k_cur,  // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur,  // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,  // [n_head_q]
             ggml_tensor * v_mla,  // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
+    llm_graph_input_attn_kv * build_attn_inp_kv() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur,  // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur,  // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur,  // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,  // [n_head_q]
             ggml_tensor * v_mla,  // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur,  // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur,  // [n_embd_head_k, n_head_k, n_tokens] optional
-            ggml_tensor * v_cur,  // [n_embd_head_v, n_head_v, n_tokens] optional
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla,  // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale,
-            int il) const;
-
-    // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
-    ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur,  // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur,  // [n_embd_head_k, n_head_k, n_tokens] optional
             ggml_tensor * v_cur,  // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
-            ggml_tensor * v_mla,  // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             ggml_tensor * sinks,  // [n_head_q]
+            ggml_tensor * v_mla,  // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
@@ -756,6 +755,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur,  // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur,  // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,  // [n_head_q]
             ggml_tensor * v_mla,  // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
@@ -765,7 +765,7 @@ struct llm_graph_context {
     //
 
     // TODO: move this implementation to llama_memory_recurrent.
-    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
     //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //       `llama_memory_recurrent`
......
#include "llama-hparams.h" #include "llama-hparams.h"
#include "ggml.h" #include "ggml.h"
#include <cassert>
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
if (dense_first) { if (dense_first) {
...@@ -161,3 +162,64 @@ bool llama_hparams::is_swa(uint32_t il) const { ...@@ -161,3 +162,64 @@ bool llama_hparams::is_swa(uint32_t il) const {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
bool llama_hparams::has_kv(uint32_t il) const {
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
return true;
}
return false;
}
// by default, all layers have kv
return true;
}
uint32_t llama_hparams::n_layer_kv() const {
uint32_t res = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (has_kv(il)) {
res++;
}
}
return res;
}
bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
{
} break;
case LLAMA_SWA_TYPE_STANDARD:
{
if (p1 - p0 >= (int32_t) n_swa) {
return true;
}
} break;
case LLAMA_SWA_TYPE_CHUNKED:
{
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
if (p0 < pos_chunk_start) {
return true;
}
} break;
case LLAMA_SWA_TYPE_SYMMETRIC:
{
const int32_t half_n_swa = (int32_t) n_swa / 2;
const int32_t pos_diff = p1 - p0;
// Mask if outside the symmetric window
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
return true;
}
} break;
}
return false;
}
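A short worked example of the new masking predicate, assuming only the llama_hparams::is_masked_swa() definition added above (the test driver itself is illustrative and not part of the commit):

    #include <cassert>

    #include "llama-hparams.h" // internal header; assumed on the include path

    int main() {
        // STANDARD window of 4: keys n_swa or more positions behind the query are masked.
        assert(!llama_hparams::is_masked_swa(4, LLAMA_SWA_TYPE_STANDARD,  /*p0=*/3, /*p1=*/6)); // 6-3 = 3 < 4  -> visible
        assert( llama_hparams::is_masked_swa(4, LLAMA_SWA_TYPE_STANDARD,  /*p0=*/2, /*p1=*/6)); // 6-2 = 4 >= 4 -> masked

        // CHUNKED window of 4: everything before the start of the query's chunk is masked.
        assert( llama_hparams::is_masked_swa(4, LLAMA_SWA_TYPE_CHUNKED,   /*p0=*/3, /*p1=*/6)); // chunk starts at 4, p0 < 4 -> masked
        assert(!llama_hparams::is_masked_swa(4, LLAMA_SWA_TYPE_CHUNKED,   /*p0=*/5, /*p1=*/6));

        // SYMMETRIC window of 4: positions further than n_swa/2 away in either direction are masked.
        assert(!llama_hparams::is_masked_swa(4, LLAMA_SWA_TYPE_SYMMETRIC, /*p0=*/4, /*p1=*/6)); // |2| <= 2 -> visible
        assert( llama_hparams::is_masked_swa(4, LLAMA_SWA_TYPE_SYMMETRIC, /*p0=*/3, /*p1=*/6)); // |3| >  2 -> masked

        return 0;
    }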
@@ -16,9 +16,10 @@ enum llama_expert_gating_func_type {
 };
 
 enum llama_swa_type {
     LLAMA_SWA_TYPE_NONE      = 0,
     LLAMA_SWA_TYPE_STANDARD  = 1,
     LLAMA_SWA_TYPE_CHUNKED   = 2,
+    LLAMA_SWA_TYPE_SYMMETRIC = 3,
 };
 
 struct llama_hparams_posnet {
@@ -41,6 +42,7 @@ struct llama_hparams {
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
+    int32_t  n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
     uint32_t n_rot;
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -69,10 +71,13 @@ struct llama_hparams {
     uint32_t n_lora_kv = 0;
     uint32_t n_ff_exp = 0;
     uint32_t n_ff_shexp = 0;
+    uint32_t n_ff_chexp = 0;
     uint32_t n_expert_shared = 0;
     uint32_t n_norm_groups = 0;
+    uint32_t n_group_experts = 0;
 
-    float expert_weights_scale = 0.0;
+    float expert_group_scale   = 0.05f;
+    float expert_weights_scale = 0.0f;
     bool expert_weights_norm = false;
     uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
     uint32_t moe_every_n_layers = 0;
@@ -82,8 +87,9 @@ struct llama_hparams {
     float f_norm_rms_eps;
     float f_norm_group_eps;
 
     float f_attn_logit_softcapping   = 50.0f;
-    float f_final_logit_softcapping  = 30.0f;
+    float f_router_logit_softcapping = 30.0f;
+    float f_final_logit_softcapping  = 30.0f;
 
     // for RWKV
     uint32_t rescale_every_n_layers = 0;
@@ -104,6 +110,11 @@ struct llama_hparams {
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul = 0.0f;
 
+    float yarn_ext_factor  = -1.0f;
+    float yarn_attn_factor =  1.0f;
+    float yarn_beta_fast   = 32.0f;
+    float yarn_beta_slow   =  1.0f;
+
     std::array<int, 4> rope_sections;
 
     // Sliding Window Attention (SWA)
@@ -136,10 +147,14 @@ struct llama_hparams {
     float f_embedding_scale = 0.0f;
     float f_attention_scale = 0.0f;
 
+    // grok-2
+    float    f_attn_out_scale = 0.0f;
+    uint32_t attn_temp_length = 0;
+
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
-    bool use_kq_norm   = true;
+    bool use_kq_norm   = false;
 
     // for Classifiers
     uint32_t n_cls_out = 1;
@@ -159,6 +174,7 @@ struct llama_hparams {
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
+    uint32_t dec_n_layer = 0;
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
@@ -226,6 +242,16 @@ struct llama_hparams {
     bool n_bskcn(uint32_t n, uint32_t il) const;
 
     bool is_swa(uint32_t il) const;
+    bool has_kv(uint32_t il) const;
+
+    // number of layers for which has_kv() returns true
+    uint32_t n_layer_kv() const;
+
+    // note that this function uses different SWA parameters from those in the hparams
+    // TODO: think of a better place for this function
+    // TODO: pack the SWA params in a struct?
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
......
@@ -59,3 +59,5 @@ std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
 std::string llama_format_tensor_shape(const struct ggml_tensor * t);
 
 std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
+
+#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
#include "llama-kv-cache-unified-iswa.h" #include "llama-kv-cache-iswa.h"
#include "llama-impl.h" #include "llama-impl.h"
#include "llama-batch.h" #include "llama-batch.h"
...@@ -8,10 +8,10 @@ ...@@ -8,10 +8,10 @@
#include <cassert> #include <cassert>
// //
// llama_kv_cache_unified_iswa // llama_kv_cache_iswa
// //
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( llama_kv_cache_iswa::llama_kv_cache_iswa(
const llama_model & model, const llama_model & model,
ggml_type type_k, ggml_type type_k,
ggml_type type_v, ggml_type type_v,
...@@ -22,9 +22,26 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( ...@@ -22,9 +22,26 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
uint32_t kv_size, uint32_t kv_size,
uint32_t n_seq_max, uint32_t n_seq_max,
uint32_t n_ubatch, uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams), unified(unified) { uint32_t n_pad,
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; const layer_filter_cb & filter,
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}
return !model.hparams.is_swa(il);
};
const layer_filter_cb filter_swa = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}
return model.hparams.is_swa(il);
};
const uint32_t size_base = kv_size; const uint32_t size_base = kv_size;
...@@ -40,25 +57,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( ...@@ -40,25 +57,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
kv_base = std::make_unique<llama_kv_cache_unified>( kv_base = std::make_unique<llama_kv_cache>(
model, std::move(filter_base), type_k, type_v, model, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad, v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE); 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>( kv_swa = std::make_unique<llama_kv_cache>(
model, std::move(filter_swa), type_k, type_v, model, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad, v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type); hparams.n_swa, hparams.swa_type, filter_swa, reuse);
} }
void llama_kv_cache_unified_iswa::clear(bool data) { void llama_kv_cache_iswa::clear(bool data) {
kv_base->clear(data); kv_base->clear(data);
kv_swa ->clear(data); kv_swa ->clear(data);
} }
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
bool res = true; bool res = true;
res = res & kv_base->seq_rm(seq_id, p0, p1); res = res & kv_base->seq_rm(seq_id, p0, p1);
...@@ -67,36 +84,44 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam ...@@ -67,36 +84,44 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam
return res; return res;
} }
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
} }
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
kv_base->seq_keep(seq_id); kv_base->seq_keep(seq_id);
kv_swa ->seq_keep(seq_id); kv_swa ->seq_keep(seq_id);
} }
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
kv_base->seq_add(seq_id, p0, p1, shift); kv_base->seq_add(seq_id, p0, p1, shift);
kv_swa ->seq_add(seq_id, p0, p1, shift); kv_swa ->seq_add(seq_id, p0, p1, shift);
} }
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
kv_base->seq_div(seq_id, p0, p1, d); kv_base->seq_div(seq_id, p0, p1, d);
kv_swa ->seq_div(seq_id, p0, p1, d); kv_swa ->seq_div(seq_id, p0, p1, d);
} }
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
// the base cache is a superset of the SWA cache, so we can just check the SWA cache // the base cache is a superset of the SWA cache, so we can just check the SWA cache
return kv_swa->seq_pos_min(seq_id); return kv_swa->seq_pos_min(seq_id);
} }
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id); return kv_swa->seq_pos_max(seq_id);
} }
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown();
for (const auto & buft_size : kv_swa->memory_breakdown()) {
mb[buft_size.first] += buft_size.second;
}
return mb;
}
llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all); GGML_UNUSED(embd_all);
// first try simple split // first try simple split
...@@ -136,7 +161,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ...@@ -136,7 +161,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
assert(sinfos_base.size() == sinfos_swa.size()); assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>( return std::make_unique<llama_kv_cache_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false); } while (false);
...@@ -172,61 +197,67 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ...@@ -172,61 +197,67 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
assert(sinfos_base.size() == sinfos_swa.size()); assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>( return std::make_unique<llama_kv_cache_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false); } while (false);
// TODO: if we fail again, we should attempt different splitting strategies // TODO: if we fail again, we should attempt different splitting strategies
// but to do that properly, we first have to refactor the batches to be more flexible // but to do that properly, we first have to refactor the batches to be more flexible
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() { llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this); return std::make_unique<llama_kv_cache_iswa_context>(this);
} }
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize); return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
} }
bool llama_kv_cache_unified_iswa::get_can_shift() const { bool llama_kv_cache_iswa::get_can_shift() const {
return kv_base->get_size() == kv_swa->get_size(); return kv_base->get_size() == kv_swa->get_size();
} }
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
kv_base->state_write(io, seq_id); if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_swa ->state_write(io, seq_id); kv_base->state_write(io, seq_id, flags);
}
kv_swa->state_write(io, seq_id, flags);
} }
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
kv_base->state_read(io, seq_id); if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_swa ->state_read(io, seq_id); kv_base->state_read(io, seq_id, flags);
}
kv_swa->state_read(io, seq_id, flags);
} }
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { llama_kv_cache * llama_kv_cache_iswa::get_base() const {
return kv_base.get(); return kv_base.get();
} }
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
return kv_swa.get(); return kv_swa.get();
} }
// //
// llama_kv_cache_unified_iswa_context // llama_kv_cache_iswa_context
// //
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {} llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
llama_kv_cache_unified_iswa * kv) : llama_kv_cache_iswa * kv) :
ctx_base(kv->get_base()->init_full()), ctx_base(kv->get_base()->init_full()),
ctx_swa (kv->get_swa ()->init_full()), ctx_swa (kv->get_swa ()->init_full()),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
} }
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_iswa * kv,
llama_context * lctx, llama_context * lctx,
bool optimize) : bool optimize) :
ctx_base(kv->get_base()->init_update(lctx, optimize)), ctx_base(kv->get_base()->init_update(lctx, optimize)),
...@@ -234,21 +265,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( ...@@ -234,21 +265,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
} }
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_iswa * kv,
slot_info_vec_t sinfos_base, slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa, slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) : std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)), ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)), ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)), ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
} }
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default; llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
bool llama_kv_cache_unified_iswa_context::next() { bool llama_kv_cache_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
ctx_base->next(); ctx_base->next();
...@@ -261,7 +292,7 @@ bool llama_kv_cache_unified_iswa_context::next() { ...@@ -261,7 +292,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
return true; return true;
} }
bool llama_kv_cache_unified_iswa_context::apply() { bool llama_kv_cache_iswa_context::apply() {
assert(!llama_memory_status_is_fail(status)); assert(!llama_memory_status_is_fail(status));
bool res = true; bool res = true;
...@@ -272,24 +303,24 @@ bool llama_kv_cache_unified_iswa_context::apply() { ...@@ -272,24 +303,24 @@ bool llama_kv_cache_unified_iswa_context::apply() {
return res; return res;
} }
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const { llama_memory_status llama_kv_cache_iswa_context::get_status() const {
return status; return status;
} }
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const { const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next]; return ubatches[i_next];
} }
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const { const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get()); return static_cast<const llama_kv_cache_context *>(ctx_base.get());
} }
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const { const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get()); return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
} }
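The llama_state_seq_flags parameter threaded through state_write()/state_read() above lets a caller serialize just the (smaller) SWA portion of a sequence's KV state. A rough sketch, under the assumption that LLAMA_STATE_SEQ_FLAGS_SWA_ONLY is the same flag checked in the code above and that the llama_context state methods shown earlier are reachable from the caller; the helper itself is hypothetical:

    #include <cstdint>
    #include <vector>

    #include "llama-context.h" // internal header; assumed available to the caller

    // Hypothetical helper: snapshot only the SWA portion of one sequence's KV state.
    // With LLAMA_STATE_SEQ_FLAGS_SWA_ONLY set, llama_kv_cache_iswa::state_write()
    // above skips kv_base and serializes kv_swa only.
    static std::vector<uint8_t> snapshot_swa_state(llama_context & ctx, llama_seq_id seq_id) {
        const llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_SWA_ONLY;

        std::vector<uint8_t> buf(ctx.state_seq_get_size(seq_id, flags));
        ctx.state_seq_get_data(seq_id, buf.data(), buf.size(), flags);

        return buf;
    }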
 #pragma once
 
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 
 #include <vector>
 
 //
-// llama_kv_cache_unified_iswa
+// llama_kv_cache_iswa
 //
 
-// utilizes two instances of llama_kv_cache_unified
+// utilizes two instances of llama_kv_cache
 // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
 
-class llama_kv_cache_unified_iswa : public llama_memory_i {
+class llama_kv_cache_iswa : public llama_memory_i {
 public:
-    llama_kv_cache_unified_iswa(
+    llama_kv_cache_iswa(
         const llama_model & model,
         ggml_type type_k,
         ggml_type type_v,
@@ -24,9 +24,11 @@ public:
         uint32_t kv_size,
         uint32_t n_seq_max,
         uint32_t n_ubatch,
-        uint32_t n_pad);
+        uint32_t n_pad,
+        const layer_filter_cb & filter,
+        const layer_reuse_cb & reuse);
 
-    ~llama_kv_cache_unified_iswa() = default;
+    ~llama_kv_cache_iswa() = default;
 
     //
     // llama_memory_i
@@ -54,52 +56,54 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
+    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
+
     // state write/load
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
-    // llama_kv_cache_unified_iswa specific API
+    // llama_kv_cache_iswa specific API
    //
 
-    llama_kv_cache_unified * get_base() const;
-    llama_kv_cache_unified * get_swa () const;
+    llama_kv_cache * get_base() const;
+    llama_kv_cache * get_swa () const;
 
 private:
     const llama_hparams & hparams;
 
     const bool unified;
 
-    std::unique_ptr<llama_kv_cache_unified> kv_base;
-    std::unique_ptr<llama_kv_cache_unified> kv_swa;
+    std::unique_ptr<llama_kv_cache> kv_base;
+    std::unique_ptr<llama_kv_cache> kv_swa;
 };
 
-class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
+class llama_kv_cache_iswa_context : public llama_memory_context_i {
 public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 
     // used for errors
-    llama_kv_cache_unified_iswa_context(llama_memory_status status);
+    llama_kv_cache_iswa_context(llama_memory_status status);
 
     // used to create a full-cache context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv);
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv);
 
     // used to create an update context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
             llama_context * lctx,
             bool optimize);
 
     // used to create a batch processing context from a batch
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
             slot_info_vec_t sinfos_base,
             slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_iswa_context();
+    virtual ~llama_kv_cache_iswa_context();
 
     //
     // llama_memory_context_i
@@ -112,14 +116,14 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_iswa_context specific API
+    // llama_kv_cache_iswa_context specific API
     //
 
-    const llama_kv_cache_unified_context * get_base() const;
-    const llama_kv_cache_unified_context * get_swa()  const;
+    const llama_kv_cache_context * get_base() const;
+    const llama_kv_cache_context * get_swa()  const;
 
 private:
-    //llama_kv_cache_unified_iswa * kv;
+    //llama_kv_cache_iswa * kv;
 
     // the index of the next ubatch to process
     size_t i_next = 0;
......
...@@ -14,27 +14,13 @@ struct llama_model; ...@@ -14,27 +14,13 @@ struct llama_model;
struct llama_context; struct llama_context;
// //
// llama_kv_cache_unified // llama_kv_cache
// //
class llama_kv_cache_unified : public llama_memory_i { class llama_kv_cache : public llama_memory_i {
public: public:
static uint32_t get_padding(const llama_cparams & cparams); static uint32_t get_padding(const llama_cparams & cparams);
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
struct defrag_info {
bool empty() const {
return ids.empty();
}
// contains information about which cell moves where:
// - cell i moves to ids[i]
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
std::vector<uint32_t> ids;
};
struct stream_copy_info { struct stream_copy_info {
bool empty() const { bool empty() const {
assert(ssrc.size() == sdst.size()); assert(ssrc.size() == sdst.size());
...@@ -52,8 +38,8 @@ public: ...@@ -52,8 +38,8 @@ public:
using idx_vec_t = std::vector<uint32_t>; using idx_vec_t = std::vector<uint32_t>;
// number of streams: ns = s1 - s0 + 1 // number of streams: ns = s1 - s0 + 1
llama_seq_id s0; uint32_t s0;
llama_seq_id s1; uint32_t s1;
std::vector<llama_seq_id> strm; // [ns] std::vector<llama_seq_id> strm; // [ns]
std::vector<idx_vec_t> idxs; // [ns] std::vector<idx_vec_t> idxs; // [ns]
...@@ -92,21 +78,22 @@ public: ...@@ -92,21 +78,22 @@ public:
using slot_info_vec_t = std::vector<slot_info>; using slot_info_vec_t = std::vector<slot_info>;
llama_kv_cache_unified( llama_kv_cache(
const llama_model & model, const llama_model & model,
layer_filter_cb && filter, ggml_type type_k,
ggml_type type_k, ggml_type type_v,
ggml_type type_v, bool v_trans,
bool v_trans, bool offload,
bool offload, bool unified,
bool unified, uint32_t kv_size,
uint32_t kv_size, uint32_t n_seq_max,
uint32_t n_seq_max, uint32_t n_pad,
uint32_t n_pad, uint32_t n_swa,
uint32_t n_swa, llama_swa_type swa_type,
llama_swa_type swa_type); const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache_unified() = default;
~llama_kv_cache() = default;
// //
// llama_memory_i // llama_memory_i
...@@ -134,13 +121,15 @@ public: ...@@ -134,13 +121,15 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
// //
// llama_kv_cache_unified specific API // llama_kv_cache specific API
// //
uint32_t get_size() const; uint32_t get_size() const;
...@@ -152,10 +141,7 @@ public: ...@@ -152,10 +141,7 @@ public:
// graph_build API // graph_build API
// //
uint32_t get_n_kv() const; uint32_t get_n_kv(const slot_info & sinfo) const;
// TODO: temporary
bool get_supports_set_rows() const;
// get views of the current state of the cache // get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
...@@ -173,7 +159,7 @@ public: ...@@ -173,7 +159,7 @@ public:
// return empty vector on failure // return empty vector on failure
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches); slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info); bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
// find a slot of kv cells that can hold the ubatch // find a slot of kv cells that can hold the ubatch
// if cont == true, then the slot must be continuous // if cont == true, then the slot must be continuous
...@@ -228,10 +214,7 @@ private: ...@@ -228,10 +214,7 @@ private:
// env: LLAMA_KV_CACHE_DEBUG // env: LLAMA_KV_CACHE_DEBUG
int debug = 0; int debug = 0;
// env: LLAMA_SET_ROWS (temporary) // this is the SWA type of the cache - not to be confused with the model SWA type
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
bool supports_set_rows = true;
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
...@@ -241,7 +224,7 @@ private: ...@@ -241,7 +224,7 @@ private:
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads; std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells_unified> v_cells; std::vector<llama_kv_cells> v_cells;
// maps from a sequence id to a stream id // maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream; std::vector<uint32_t> seq_to_stream;
...@@ -254,9 +237,6 @@ private: ...@@ -254,9 +237,6 @@ private:
// model layer id -> KV cache layer id // model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids; std::unordered_map<int32_t, int32_t> map_layer_ids;
// return non-empty vector if cells have been moved
defrag_info defrag_prepare(int32_t n_max_nodes) const;
size_t total_size() const; size_t total_size() const;
size_t size_k_bytes() const; size_t size_k_bytes() const;
...@@ -277,11 +257,6 @@ private: ...@@ -277,11 +257,6 @@ private:
llm_graph_result * res, llm_graph_result * res,
llama_context * lctx) const; llama_context * lctx) const;
ggml_cgraph * build_graph_defrag(
llm_graph_result * res,
llama_context * lctx,
const defrag_info & dinfo) const;
struct cell_ranges_t { struct cell_ranges_t {
uint32_t strm; uint32_t strm;
...@@ -295,35 +270,33 @@ private: ...@@ -295,35 +270,33 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
}; };
class llama_kv_cache_unified_context : public llama_memory_context_i { class llama_kv_cache_context : public llama_memory_context_i {
public: public:
// some shorthands // some shorthands
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info; using stream_copy_info = llama_kv_cache::stream_copy_info;
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
// used for errors // used for errors
llama_kv_cache_unified_context(llama_memory_status status); llama_kv_cache_context(llama_memory_status status);
// used to create a full-cache context // used to create a full-cache context
llama_kv_cache_unified_context( llama_kv_cache_context(
llama_kv_cache_unified * kv); llama_kv_cache * kv);
// used to create an update context // used to create an update context
llama_kv_cache_unified_context( llama_kv_cache_context(
llama_kv_cache_unified * kv, llama_kv_cache * kv,
llama_context * lctx, llama_context * lctx,
bool do_shift, bool do_shift,
defrag_info dinfo,
stream_copy_info sc_info); stream_copy_info sc_info);
// used to create a batch procesing context from a batch // used to create a batch procesing context from a batch
llama_kv_cache_unified_context( llama_kv_cache_context(
llama_kv_cache_unified * kv, llama_kv_cache * kv,
slot_info_vec_t sinfos, slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_context(); virtual ~llama_kv_cache_context();
// //
// llama_memory_context_i // llama_memory_context_i
...@@ -336,22 +309,27 @@ public: ...@@ -336,22 +309,27 @@ public:
const llama_ubatch & get_ubatch() const override; const llama_ubatch & get_ubatch() const override;
// //
// llama_kv_cache_unified_context specific API // llama_kv_cache_context specific API
// //
uint32_t get_n_kv() const; uint32_t get_n_kv() const;
// TODO: temporary
bool get_supports_set_rows() const;
// get views of the current state of the cache // get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location // store k_cur and v_cur in the cache based on the provided head location
// note: the heads in k_cur and v_cur should be laid out contiguously in memory
// - k_cur [n_embd_head_k, n_head_k, n_tokens]
// - k_idxs [n_tokens]
// - v_cur [n_embd_head_v, n_head_v, n_tokens]
// - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
// create destination indices for each head of the current batch for where it would be written in the KV cache
// the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
// helps understand the implementation logic of cpy_k and cpy_v
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
...@@ -365,7 +343,7 @@ public: ...@@ -365,7 +343,7 @@ public:
private: private:
llama_memory_status status; llama_memory_status status;
llama_kv_cache_unified * kv; llama_kv_cache * kv;
llama_context * lctx; llama_context * lctx;
// //
...@@ -374,8 +352,6 @@ private: ...@@ -374,8 +352,6 @@ private:
bool do_shift = false; bool do_shift = false;
defrag_info dinfo;
stream_copy_info sc_info; stream_copy_info sc_info;
// //
......
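The cpy_k/cpy_v calls above write the current ubatch into the cache through explicit per-token destination indices (k_idxs/v_idxs, produced by build_input_k_idxs/build_input_v_idxs) rather than through a single contiguous head offset. The standalone sketch below is not the llama.cpp implementation; it only illustrates the scatter-by-index idea on a hypothetical flat K buffer, where the cells written in one pass need not be contiguous.

#include <cstdio>
#include <vector>

int main() {
    const int n_embd  = 4;   // row width of the hypothetical K cache
    const int kv_size = 8;   // number of cells
    std::vector<float> k_cache(kv_size * n_embd, 0.0f);

    // current ubatch: 3 tokens, each contributing one K row
    std::vector<float> k_cur = {
        1, 1, 1, 1,
        2, 2, 2, 2,
        3, 3, 3, 3,
    };

    // destination cell index per token (analogous to k_idxs); cells need not be contiguous
    std::vector<int> k_idxs = { 5, 0, 6 };

    // scatter each token row into its destination cell
    for (size_t t = 0; t < k_idxs.size(); ++t) {
        for (int e = 0; e < n_embd; ++e) {
            k_cache[k_idxs[t] * n_embd + e] = k_cur[t * n_embd + e];
        }
    }

    for (int c = 0; c < kv_size; ++c) {
        printf("cell %d:", c);
        for (int e = 0; e < n_embd; ++e) {
            printf(" %g", k_cache[c * n_embd + e]);
        }
        printf("\n");
    }
    return 0;
}

As noted in the comments above, a transposed V cache would need one destination index per value (n_tokens*n_embd_v_gqa) rather than one per token, but the scatter pattern is the same.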
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
// meta information about KV cells that can be part of multiple sequences at the same time // meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests // TODO: add unit tests
class llama_kv_cells_unified { class llama_kv_cells {
public: public:
void reset() { void reset() {
for (uint32_t i = 0; i < pos.size(); ++i) { for (uint32_t i = 0; i < pos.size(); ++i) {
...@@ -77,30 +77,30 @@ public: ...@@ -77,30 +77,30 @@ public:
} }
// move cell isrc to idst (used during defrag) // move cell isrc to idst (used during defrag)
void mv(uint32_t isrc, uint32_t idst) { //void mv(uint32_t isrc, uint32_t idst) {
assert(isrc < pos.size()); // assert(isrc < pos.size());
assert(idst < pos.size()); // assert(idst < pos.size());
assert(pos[idst] == -1); // assert(pos[idst] == -1);
assert(pos[isrc] != -1); // assert(pos[isrc] != -1);
pos [idst] = pos [isrc]; // pos [idst] = pos [isrc];
shift[idst] = shift[isrc]; // shift[idst] = shift[isrc];
seq [idst] = seq [isrc]; // seq [idst] = seq [isrc];
pos [isrc] = -1; // pos [isrc] = -1;
shift[isrc] = 0; // shift[isrc] = 0;
seq [isrc].reset(); // seq [isrc].reset();
used.erase (isrc); // used.erase (isrc);
used.insert(idst); // used.insert(idst);
} //}
// copy the state of cells [i, i + n) (used to save/restore the state of the cells) // copy the state of cells [i, i + n) (used to save/restore the state of the cells)
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { llama_kv_cells cp(uint32_t i, uint32_t n) const {
assert(i + n <= pos.size()); assert(i + n <= pos.size());
llama_kv_cells_unified res; llama_kv_cells res;
res.resize(n); res.resize(n);
...@@ -117,8 +117,8 @@ public: ...@@ -117,8 +117,8 @@ public:
} }
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const { llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells_unified res; llama_kv_cells res;
res.resize(idxs.size()); res.resize(idxs.size());
...@@ -135,7 +135,7 @@ public: ...@@ -135,7 +135,7 @@ public:
} }
// set the state of cells [i, i + other.pos.size()) (used to save/restore the state of the cells) // set the state of cells [i, i + other.pos.size()) (used to save/restore the state of the cells)
void set(uint32_t i, const llama_kv_cells_unified & other) { void set(uint32_t i, const llama_kv_cells & other) {
assert(i + other.pos.size() <= pos.size()); assert(i + other.pos.size() <= pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) { for (uint32_t j = 0; j < other.pos.size(); ++j) {
...@@ -165,7 +165,7 @@ public: ...@@ -165,7 +165,7 @@ public:
} }
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) { void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
assert(idxs.size() == other.pos.size()); assert(idxs.size() == other.pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) { for (uint32_t j = 0; j < other.pos.size(); ++j) {
......
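The cp()/set() pair above is the save/restore primitive for the cell metadata: cp() copies a contiguous range (or an arbitrary index list) out of the cells, and set() writes it back to the same locations. A minimal standalone sketch of that roundtrip on a heavily simplified cells structure (positions only, no sequence sets or shift tracking) follows; it is illustrative only, not the actual llama_kv_cells class.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// simplified stand-in for llama_kv_cells: positions only
struct cells_t {
    std::vector<int32_t> pos;   // -1 means the cell is empty

    void resize(uint32_t n) { pos.assign(n, -1); }

    // copy the state of cells [i, i + n)
    cells_t cp(uint32_t i, uint32_t n) const {
        assert(i + n <= pos.size());
        cells_t res;
        res.resize(n);
        for (uint32_t j = 0; j < n; ++j) {
            res.pos[j] = pos[i + j];
        }
        return res;
    }

    // restore the state of cells [i, i + other.pos.size())
    void set(uint32_t i, const cells_t & other) {
        assert(i + other.pos.size() <= pos.size());
        for (uint32_t j = 0; j < other.pos.size(); ++j) {
            pos[i + j] = other.pos[j];
        }
    }
};

int main() {
    cells_t cells;
    cells.resize(8);
    cells.pos = { 0, 1, 2, 3, -1, -1, -1, -1 };

    cells_t saved = cells.cp(0, 4);                       // snapshot the occupied range
    cells.pos = { -1, -1, -1, -1, -1, -1, -1, -1 };       // clobber everything
    cells.set(0, saved);                                  // restore the snapshot

    for (size_t i = 0; i < cells.pos.size(); ++i) {
        printf("cell %zu: pos=%d\n", i, (int) cells.pos[i]);
    }
    return 0;
}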
...@@ -9,32 +9,29 @@ ...@@ -9,32 +9,29 @@
// //
llama_memory_hybrid::llama_memory_hybrid( llama_memory_hybrid::llama_memory_hybrid(
const llama_model & model, const llama_model & model,
/* attn */ /* attn */
ggml_type type_k, ggml_type type_k,
ggml_type type_v, ggml_type type_v,
bool v_trans, bool v_trans,
uint32_t kv_size, uint32_t kv_size,
uint32_t n_pad, uint32_t n_pad,
uint32_t n_swa, uint32_t n_swa,
llama_swa_type swa_type, llama_swa_type swa_type,
/* recurrent */ /* recurrent */
ggml_type type_r, ggml_type type_r,
ggml_type type_s, ggml_type type_s,
uint32_t rs_size, uint32_t rs_size,
/* common */ /* common */
uint32_t n_seq_max, uint32_t n_seq_max,
bool offload, bool offload,
bool unified, bool unified,
/* layer filters */ /* layer filters */
layer_filter_cb && filter_attn, const layer_filter_cb & filter_attn,
layer_filter_cb && filter_recr) : const layer_filter_cb & filter_recr) :
hparams(model.hparams), hparams(model.hparams),
mem_attn(new llama_kv_cache_unified( mem_attn(new llama_kv_cache(
model, model,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
type_k, type_k,
type_v, type_v,
v_trans, v_trans,
...@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid( ...@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max, n_seq_max,
n_pad, n_pad,
n_swa, n_swa,
swa_type swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
)), )),
mem_recr(new llama_memory_recurrent( mem_recr(new llama_memory_recurrent(
model, model,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr,
type_r, type_r,
type_s, type_s,
offload, offload,
rs_size, rs_size,
n_seq_max n_seq_max,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
)) {} )) {}
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
...@@ -165,17 +166,29 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const { ...@@ -165,17 +166,29 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
} }
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
for (const auto & buft_size : mem_recr->memory_breakdown()) {
mb[buft_size.first] += buft_size.second;
}
return mb;
}
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
mem_attn->state_write(io, seq_id); mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id); mem_recr->state_write(io, seq_id);
} }
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) { void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);
mem_attn->state_read(io, seq_id); mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id); mem_recr->state_read(io, seq_id);
} }
llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const { llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
return mem_attn.get(); return mem_attn.get();
} }
...@@ -206,7 +219,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( ...@@ -206,7 +219,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
std::vector<llama_ubatch> ubatches) : std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)), ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
} }
...@@ -244,8 +257,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const { ...@@ -244,8 +257,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
return ubatches[i_next]; return ubatches[i_next];
} }
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const { const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get()); return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
} }
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const { const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
......
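The new memory_breakdown() for the hybrid memory simply sums the per-buffer-type breakdowns reported by the attention cache and the recurrent memory. A small standalone sketch of that merge is shown below, using plain strings in place of ggml_backend_buffer_type_t so it compiles without ggml.

#include <cstdio>
#include <map>
#include <string>

// stand-in for std::map<ggml_backend_buffer_type_t, size_t>
using breakdown_t = std::map<std::string, size_t>;

// sum two breakdowns per buffer type, mirroring llama_memory_hybrid::memory_breakdown()
static breakdown_t merge_breakdowns(const breakdown_t & attn, const breakdown_t & recr) {
    breakdown_t mb = attn;
    for (const auto & buft_size : recr) {
        mb[buft_size.first] += buft_size.second;
    }
    return mb;
}

int main() {
    // hypothetical sizes, for illustration only
    const breakdown_t attn = { { "CUDA0", 512u << 20 }, { "CPU", 64u << 20 } };
    const breakdown_t recr = { { "CUDA0",  32u << 20 } };

    for (const auto & [buft, size] : merge_breakdowns(attn, recr)) {
        printf("%s: %zu MiB\n", buft.c_str(), size >> 20);
    }
    return 0;
}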
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "llama-batch.h" #include "llama-batch.h"
#include "llama-graph.h" #include "llama-graph.h"
#include "llama-kv-cache-unified.h" #include "llama-kv-cache.h"
#include "llama-memory.h" #include "llama-memory.h"
#include "llama-memory-recurrent.h" #include "llama-memory-recurrent.h"
...@@ -13,36 +13,32 @@ ...@@ -13,36 +13,32 @@
// llama_memory_hybrid // llama_memory_hybrid
// //
// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to // utilizes instances of llama_memory_recurrent and llama_kv_cache to
// support models where each layer may be either attention-based or recurrent // support models where each layer may be either attention-based or recurrent
class llama_memory_hybrid : public llama_memory_i { class llama_memory_hybrid : public llama_memory_i {
public: public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_hybrid( llama_memory_hybrid(
const llama_model & model, const llama_model & model,
/* attn */ /* attn */
ggml_type type_k, ggml_type type_k,
ggml_type type_v, ggml_type type_v,
bool v_trans, bool v_trans,
uint32_t kv_size, uint32_t kv_size,
uint32_t n_pad, uint32_t n_pad,
uint32_t n_swa, uint32_t n_swa,
llama_swa_type swa_type, llama_swa_type swa_type,
/* recurrent */ /* recurrent */
ggml_type type_r, ggml_type type_r,
ggml_type type_s, ggml_type type_s,
uint32_t rs_size, uint32_t rs_size,
/* common */ /* common */
uint32_t n_seq_max, uint32_t n_seq_max,
bool offload, bool offload,
bool unified, bool unified,
/* layer filters */ /* layer filters */
layer_filter_cb && filter_attn = nullptr, const layer_filter_cb & filter_attn = nullptr,
layer_filter_cb && filter_recr = nullptr); const layer_filter_cb & filter_recr = nullptr);
~llama_memory_hybrid() = default; ~llama_memory_hybrid() = default;
...@@ -72,28 +68,30 @@ public: ...@@ -72,28 +68,30 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
// //
// llama_memory_hybrid specific API // llama_memory_hybrid specific API
// //
llama_kv_cache_unified * get_mem_attn() const; llama_kv_cache * get_mem_attn() const;
llama_memory_recurrent * get_mem_recr() const; llama_memory_recurrent * get_mem_recr() const;
private: private:
const llama_hparams & hparams; const llama_hparams & hparams;
const std::unique_ptr<llama_kv_cache_unified> mem_attn; const std::unique_ptr<llama_kv_cache> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr; const std::unique_ptr<llama_memory_recurrent> mem_recr;
}; };
class llama_memory_hybrid_context : public llama_memory_context_i { class llama_memory_hybrid_context : public llama_memory_context_i {
public: public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
// init failure // init failure
explicit llama_memory_hybrid_context(llama_memory_status status); explicit llama_memory_hybrid_context(llama_memory_status status);
...@@ -125,7 +123,7 @@ public: ...@@ -125,7 +123,7 @@ public:
// llama_memory_hybrid_context // llama_memory_hybrid_context
// //
const llama_kv_cache_unified_context * get_attn() const; const llama_kv_cache_context * get_attn() const;
const llama_memory_recurrent_context * get_recr() const; const llama_memory_recurrent_context * get_recr() const;
private: private:
......
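When no layer filters are supplied, the hybrid constructor above falls back to defaults: the attention cache covers every layer that is not recurrent, and the recurrent memory covers the rest. The standalone sketch below illustrates that partitioning with plain std::function callbacks over a hypothetical per-layer flag table; it mirrors the filter contract rather than the actual constructor.

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

using layer_filter_cb = std::function<bool(int32_t il)>;

int main() {
    // hypothetical model: layers 2 and 5 are recurrent, the rest use attention
    const std::vector<bool> is_recurrent = { false, false, true, false, false, true };

    // default filters, analogous to the fallbacks used by llama_memory_hybrid
    layer_filter_cb filter_attn = [&](int32_t il) { return !is_recurrent[il]; };
    layer_filter_cb filter_recr = [&](int32_t il) { return  is_recurrent[il]; };

    for (int32_t il = 0; il < (int32_t) is_recurrent.size(); ++il) {
        printf("layer %d -> %s cache\n", il,
               filter_attn(il) ? "attention" : (filter_recr(il) ? "recurrent" : "none"));
    }
    return 0;
}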
...@@ -16,13 +16,13 @@ ...@@ -16,13 +16,13 @@
// //
llama_memory_recurrent::llama_memory_recurrent( llama_memory_recurrent::llama_memory_recurrent(
const llama_model & model, const llama_model & model,
layer_filter_cb && filter, ggml_type type_r,
ggml_type type_r, ggml_type type_s,
ggml_type type_s, bool offload,
bool offload, uint32_t mem_size,
uint32_t mem_size, uint32_t n_seq_max,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer; const int32_t n_layer = hparams.n_layer;
head = 0; head = 0;
...@@ -359,6 +359,14 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { ...@@ -359,6 +359,14 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result; return result;
} }
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
}
return ret;
}
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do { do {
balloc.split_reset(); balloc.split_reset();
...@@ -680,7 +688,9 @@ size_t llama_memory_recurrent::size_s_bytes() const { ...@@ -680,7 +688,9 @@ size_t llama_memory_recurrent::size_s_bytes() const {
return size_s_bytes; return size_s_bytes;
} }
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0; uint32_t cell_count = 0;
...@@ -718,7 +728,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq ...@@ -718,7 +728,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
state_write_data(io, cell_ranges); state_write_data(io, cell_ranges);
} }
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);
uint32_t cell_count; uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count)); io.read_to(&cell_count, sizeof(cell_count));
......
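state_write() above first gathers the cells that belong to the requested sequence into [from, to) ranges before serializing metadata and data. A standalone sketch of that range-collection step over a simple per-cell ownership array is shown below (illustrative only; the real code walks the recurrent memory's own cells).

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // hypothetical cell ownership: which sequence id occupies each cell (-1 = empty)
    const std::vector<int32_t> cell_seq = { 0, 0, 1, 0, 0, 0, -1, 1 };
    const int32_t seq_id = 0;

    // collect [from, to) ranges of cells that belong to seq_id
    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges;
    uint32_t cell_count = 0;
    for (uint32_t i = 0; i < cell_seq.size(); ++i) {
        if (cell_seq[i] != seq_id) {
            continue;
        }
        cell_count++;
        if (!cell_ranges.empty() && cell_ranges.back().second == i) {
            cell_ranges.back().second = i + 1;    // extend the current range
        } else {
            cell_ranges.emplace_back(i, i + 1);   // start a new range
        }
    }

    printf("seq %d occupies %u cells\n", seq_id, cell_count);
    for (const auto & r : cell_ranges) {
        printf("  range [%u, %u)\n", r.first, r.second);
    }
    return 0;
}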
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "llama-graph.h" #include "llama-graph.h"
#include "llama-memory.h" #include "llama-memory.h"
#include <map>
#include <set> #include <set>
#include <vector> #include <vector>
...@@ -12,21 +13,17 @@ ...@@ -12,21 +13,17 @@
// //
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
// see the implementation of llama_kv_cache_unified_context_i for an example how to do it // see the implementation of llama_kv_cache_context_i for an example how to do it
class llama_memory_recurrent : public llama_memory_i { class llama_memory_recurrent : public llama_memory_i {
public: public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_recurrent( llama_memory_recurrent(
const llama_model & model, const llama_model & model,
layer_filter_cb && filter, ggml_type type_r,
ggml_type type_r, ggml_type type_s,
ggml_type type_s, bool offload,
bool offload, uint32_t mem_size,
uint32_t mem_size, uint32_t n_seq_max,
uint32_t n_seq_max); const layer_filter_cb & filter);
~llama_memory_recurrent() = default; ~llama_memory_recurrent() = default;
...@@ -54,6 +51,8 @@ public: ...@@ -54,6 +51,8 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
bool prepare(const std::vector<llama_ubatch> & ubatches); bool prepare(const std::vector<llama_ubatch> & ubatches);
// find a contiguous slot of memory cells and emplace the ubatch there // find a contiguous slot of memory cells and emplace the ubatch there
...@@ -63,8 +62,8 @@ public: ...@@ -63,8 +62,8 @@ public:
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences uint32_t size = 0; // total number of cells, shared across all sequences
......
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
#include "llama.h" #include "llama.h"
#include <map>
#include <memory> #include <memory>
#include <functional>
struct llama_ubatch; struct llama_ubatch;
...@@ -36,8 +38,8 @@ bool llama_memory_status_is_fail(llama_memory_status status); ...@@ -36,8 +38,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);
// the interface for managing the memory context during batch processing // the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see: // this interface is implemented per memory type. see:
// - llama_kv_cache_unified_context // - llama_kv_cache_context
// - llama_kv_cache_unified_iswa_context // - llama_kv_cache_iswa_context
// ... // ...
// //
// the only method that should mutate the memory and the memory context is llama_memory_i::apply() // the only method that should mutate the memory and the memory context is llama_memory_i::apply()
...@@ -64,6 +66,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>; ...@@ -64,6 +66,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
// general concept of LLM memory // general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types // the KV cache is a type of LLM memory, but there can be other types
struct llama_memory_i { struct llama_memory_i {
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
// this callback is used to specify which layers should reuse memory from other layers
// return negative value to indicate that the layer il should not reuse memory
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
virtual ~llama_memory_i() = default; virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache // split the input batch into a set of ubatches and verify that they can fit into the cache
...@@ -77,7 +86,7 @@ struct llama_memory_i { ...@@ -77,7 +86,7 @@ struct llama_memory_i {
// simulate full cache, used for allocating worst-case compute buffers // simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_context_ptr init_full() = 0; virtual llama_memory_context_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc. // prepare for any pending memory updates, such as shifts, copies, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0; virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
...@@ -100,17 +109,14 @@ struct llama_memory_i { ...@@ -100,17 +109,14 @@ struct llama_memory_i {
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
virtual std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const = 0;
// //
// state write/read // state write/read
// //
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
}; };
using llama_memory_ptr = std::unique_ptr<llama_memory_i>; using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
// TODO: temporary until the llama_kv_cache is removed from the public API
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
};
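The two callbacks now declared on llama_memory_i give a memory implementation two per-layer decisions: whether the layer gets cache storage at all (layer_filter_cb) and whether it should reuse another layer's memory instead of allocating its own (layer_reuse_cb, with a negative return meaning no reuse). The sketch below is a hypothetical illustration of that contract with hand-picked example lambdas; the policy shown is invented for the example, not taken from any real model.

#include <cstdint>
#include <cstdio>
#include <functional>

// callback aliases as declared on llama_memory_i
using layer_filter_cb = std::function<bool(int32_t il)>;
using layer_reuse_cb  = std::function<int32_t(int32_t il)>;

int main() {
    const int32_t n_layer = 6;

    // hypothetical policy: skip layer 0, and let odd layers reuse the memory of the previous even layer
    layer_filter_cb filter = [](int32_t il) { return il != 0; };
    layer_reuse_cb  reuse  = [](int32_t il) { return (il % 2 == 1) ? il - 1 : -1; };

    for (int32_t il = 0; il < n_layer; ++il) {
        if (!filter(il)) {
            printf("layer %d: no cache\n", il);
        } else if (reuse(il) >= 0) {
            printf("layer %d: reuses memory of layer %d\n", il, reuse(il));
        } else {
            printf("layer %d: owns its memory\n", il);
        }
    }
    return 0;
}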