Unverified Commit 0cefd46f authored by Jeffrey Morgan's avatar Jeffrey Morgan Committed by GitHub
Browse files

llama: update to commit de4c07f93 (#10655)

parent ad035ad5
UPSTREAM=https://github.com/ggerganov/llama.cpp.git UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor WORKDIR=llama/vendor
FETCH_HEAD=e1e8e0991ffd9e99a445c6812bb519d5bac9f4b5 FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618
.PHONY: help .PHONY: help
help: help:
......
int LLAMA_BUILD_NUMBER = 0; int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "e1e8e0991ffd9e99a445c6812bb519d5bac9f4b5"; char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618";
char const *LLAMA_COMPILER = ""; char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = ""; char const *LLAMA_BUILD_TARGET = "";
...@@ -10,11 +10,11 @@ include common/stb_image.* ...@@ -10,11 +10,11 @@ include common/stb_image.*
include include/ include include/
include include/llama.* include include/llama.*
include include/llama-*.* include include/llama-*.*
include examples/ include tools/
include examples/llava/ include tools/mtmd/
include examples/llava/clip.* include tools/mtmd/clip.*
include examples/llava/clip-impl.* include tools/mtmd/clip-impl.*
include examples/llava/llava.* include tools/mtmd/llava.*
include src/ include src/
include src/llama.* include src/llama.*
include src/llama-*.* include src/llama-*.*
......
...@@ -1096,7 +1096,6 @@ struct llama_context_params common_context_params_to_llama(const common_params & ...@@ -1096,7 +1096,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads; params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding; cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type; cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base; cparams.rope_freq_base = params.rope_freq_base;
...@@ -1114,6 +1113,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & ...@@ -1114,6 +1113,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.offload_kqv = !params.no_kv_offload; cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn; cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf; cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
if (params.reranking) { if (params.reranking) {
cparams.embeddings = true; cparams.embeddings = true;
...@@ -1565,3 +1565,20 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c ...@@ -1565,3 +1565,20 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result; return result;
} }
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
const int64_t ne_datapoint = llama_n_ctx(ctx);
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
ggml_opt_dataset_t result = ggml_opt_dataset_init(
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
for (int64_t idata = 0; idata < ndata; ++idata) {
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
}
return result;
}
...@@ -66,7 +66,6 @@ enum llama_example { ...@@ -66,7 +66,6 @@ enum llama_example {
LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_COMMON,
LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SPECULATIVE,
LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MAIN,
LLAMA_EXAMPLE_INFILL,
LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_EMBEDDING,
LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_PERPLEXITY,
LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_RETRIEVAL,
...@@ -96,6 +95,7 @@ enum common_sampler_type { ...@@ -96,6 +95,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
}; };
// dimensionality reduction methods, used by cvector-generator // dimensionality reduction methods, used by cvector-generator
...@@ -161,6 +161,7 @@ struct common_params_sampling { ...@@ -161,6 +161,7 @@ struct common_params_sampling {
std::vector<enum common_sampler_type> samplers = { std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES, COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY, COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TYPICAL_P, COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P, COMMON_SAMPLER_TYPE_TOP_P,
...@@ -323,7 +324,6 @@ struct common_params { ...@@ -323,7 +324,6 @@ struct common_params {
bool ctx_shift = true; // context shift on inifinite text generation bool ctx_shift = true; // context shift on inifinite text generation
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation bool verbose_prompt = false; // print prompt tokens before generation
...@@ -332,6 +332,7 @@ struct common_params { ...@@ -332,6 +332,7 @@ struct common_params {
bool no_kv_offload = false; // disable KV offloading bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool single_turn = false; // single turn chat conversation bool single_turn = false; // single turn chat conversation
...@@ -340,7 +341,7 @@ struct common_params { ...@@ -340,7 +341,7 @@ struct common_params {
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see examples/llava) // multimodal models (see tools/mtmd)
struct common_params_model mmproj; struct common_params_model mmproj;
bool mmproj_use_gpu = true; // use GPU for multimodal model bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model bool no_mmproj = false; // explicitly disable multimodal model
...@@ -409,13 +410,14 @@ struct common_params { ...@@ -409,13 +410,14 @@ struct common_params {
bool process_output = false; // collect data for the output tensor bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity bool compute_ppl = true; // whether to compute perplexity
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
// cvector-generator params // cvector-generator params
int n_pca_batch = 100; int n_pca_batch = 100;
int n_pca_iterations = 1000; int n_pca_iterations = 1000;
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
bool spm_infill = false; // suffix/prefix/middle pattern for infill bool spm_infill = false; // suffix/prefix/middle pattern for infill
...@@ -664,3 +666,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count"; ...@@ -664,3 +666,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
} }
//
// training utils
//
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
#include "sampling.h" #include "sampling.h"
#include "common.h" #include "common.h"
#include "log.h"
#include <cmath> #include <cmath>
#include <unordered_map> #include <unordered_map>
...@@ -229,51 +230,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co ...@@ -229,51 +230,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
params.logit_bias.data())); params.logit_bias.data()));
if (params.mirostat == 0) { if (params.mirostat == 0) {
if (params.top_n_sigma >= 0) { for (const auto & cnstr : params.samplers) {
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); switch (cnstr) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp)); case COMMON_SAMPLER_TYPE_DRY:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); {
} else { std::vector<const char *> c_breakers;
for (const auto & cnstr : params.samplers) { c_breakers.reserve(params.dry_sequence_breakers.size());
switch (cnstr) { for (const auto & str : params.dry_sequence_breakers) {
case COMMON_SAMPLER_TYPE_DRY: c_breakers.push_back(str.c_str());
{
std::vector<const char *> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break;
case COMMON_SAMPLER_TYPE_TOP_K: llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); }
break; break;
case COMMON_SAMPLER_TYPE_TOP_P: case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break; break;
case COMMON_SAMPLER_TYPE_MIN_P: case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_XTC: case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
break; break;
case COMMON_SAMPLER_TYPE_TYPICAL_P: case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TEMPERATURE: case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break; break;
case COMMON_SAMPLER_TYPE_INFILL: case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_PENALTIES: case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break; break;
default: case COMMON_SAMPLER_TYPE_INFILL:
GGML_ASSERT(false && "unknown sampler type"); llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
} break;
case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
} }
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
...@@ -475,6 +473,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { ...@@ -475,6 +473,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_XTC: return 'x';
...@@ -490,6 +489,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { ...@@ -490,6 +489,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_XTC: return "xtc";
...@@ -504,6 +504,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect ...@@ -504,6 +504,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "dry", COMMON_SAMPLER_TYPE_DRY }, { "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K }, { "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P }, { "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P }, { "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
...@@ -517,6 +518,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect ...@@ -517,6 +518,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map { std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K }, { "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P }, { "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, { "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
...@@ -533,14 +535,16 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect ...@@ -533,14 +535,16 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
auto sampler = sampler_canonical_name_map.find(name); auto sampler = sampler_canonical_name_map.find(name);
if (sampler != sampler_canonical_name_map.end()) { if (sampler != sampler_canonical_name_map.end()) {
samplers.push_back(sampler->second); samplers.push_back(sampler->second);
} else { continue;
if (allow_alt_names) { }
sampler = sampler_alt_name_map.find(name); if (allow_alt_names) {
if (sampler != sampler_alt_name_map.end()) { sampler = sampler_alt_name_map.find(name);
samplers.push_back(sampler->second); if (sampler != sampler_alt_name_map.end()) {
} samplers.push_back(sampler->second);
continue;
} }
} }
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
} }
return samplers; return samplers;
...@@ -552,6 +556,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri ...@@ -552,6 +556,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
...@@ -566,6 +571,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri ...@@ -566,6 +571,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
const auto sampler = sampler_name_map.find(c); const auto sampler = sampler_name_map.find(c);
if (sampler != sampler_name_map.end()) { if (sampler != sampler_name_map.end()) {
samplers.push_back(sampler->second); samplers.push_back(sampler->second);
} else {
LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
} }
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "ggml.h" #include "ggml.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-opt.h"
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
...@@ -112,6 +113,7 @@ extern "C" { ...@@ -112,6 +113,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
}; };
enum llama_rope_type { enum llama_rope_type {
...@@ -352,20 +354,19 @@ extern "C" { ...@@ -352,20 +354,19 @@ extern "C" {
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
// Abort callback
// if it returns true, execution of llama_decode() will be aborted
// currently works only with CPU execution
ggml_abort_callback abort_callback;
void * abort_callback_data;
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
// TODO: move at the end of the struct
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits) bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL] bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings bool no_perf; // whether to measure performance timings
bool op_offload; // whether to offload host tensor operations to device
bool cross_attn; // whether to use cross attention bool cross_attn; // whether to use cross attention
// Abort callback
// if it returns true, execution of llama_decode() will be aborted
// currently works only with CPU execution
ggml_abort_callback abort_callback;
void * abort_callback_data;
}; };
// model quantization parameters // model quantization parameters
...@@ -447,6 +448,10 @@ extern "C" { ...@@ -447,6 +448,10 @@ extern "C" {
size_t n_paths, size_t n_paths,
struct llama_model_params params); struct llama_model_params params);
LLAMA_API void llama_model_save_to_file(
const struct llama_model * model,
const char * path_model);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
"use llama_model_free instead"); "use llama_model_free instead");
...@@ -930,14 +935,19 @@ extern "C" { ...@@ -930,14 +935,19 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init() // Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch); LLAMA_API void llama_batch_free(struct llama_batch batch);
// Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Process a batch of tokens.
// Stores the encoder output internally for later use by the decoder cross-attention layers. // In contrast to llama_decode() - this call does not use KV cache.
// For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
// 0 - success // 0 - success
// < 0 - error. the KV cache state is restored to the state before this call // < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_encode( LLAMA_API int32_t llama_encode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch); struct llama_batch batch);
// Process a batch of tokens.
// Requires KV cache.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning. // Positive return values does not mean a fatal error, but rather a warning.
// 0 - success // 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
...@@ -1434,6 +1444,37 @@ extern "C" { ...@@ -1434,6 +1444,37 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
//
// training
//
// function that returns whether or not a given tensor contains trainable parameters
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
// always returns true
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
struct llama_opt_params {
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
};
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
LLAMA_API void llama_opt_epoch(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
...@@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ ...@@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
std::vector<ggml_backend_buffer_type_t> buft_extra; std::vector<ggml_backend_buffer_type_t> buft_extra;
{ {
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
...@@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ ...@@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
buft = ggml_backend_dev_buffer_type(cpu_dev); buft = ggml_backend_dev_buffer_type(cpu_dev);
break; break;
......
...@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { ...@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
return ubatch; return ubatch;
} }
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
GGML_ASSERT(batch.n_tokens >= 0); GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch; this->batch = &batch;
this->n_embd = n_embd; this->n_embd = n_embd;
...@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ...@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
for (size_t i = 0; i < n_tokens; ++i) { for (size_t i = 0; i < n_tokens; ++i) {
ids[i] = i; ids[i] = i;
} }
if (simple_split) { if (simple_split) {
seq.resize(1); seq.resize(1);
llama_sbatch_seq & s = seq[0]; llama_sbatch_seq & s = seq[0];
...@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ...@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
s.length = n_tokens; s.length = n_tokens;
return; return;
} }
std::sort(ids.begin(), ids.end(), std::sort(ids.begin(), ids.end(),
[&batch](size_t a, size_t b) { [&batch](size_t a, size_t b) {
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
...@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ...@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
return n_seq_a > n_seq_b; return n_seq_a > n_seq_b;
} }
); );
// init seq // init seq
llama_sbatch_seq * last_seq = nullptr; llama_sbatch_seq * last_seq = nullptr;
...@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ...@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
seq.push_back(new_seq); seq.push_back(new_seq);
last_seq = &seq.back(); last_seq = &seq.back();
} }
// keep shared prompts first at the end, then sort by length descending. // keep shared prompts first at the end, then sort by length descending.
std::sort(seq.begin(), seq.end(), std::sort(seq.begin(), seq.end(),
[](llama_sbatch_seq & a, llama_sbatch_seq & b) { [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
......
...@@ -70,7 +70,8 @@ struct llama_sbatch { ...@@ -70,7 +70,8 @@ struct llama_sbatch {
// sequence-wise split // sequence-wise split
llama_ubatch split_seq(size_t n_ubatch); llama_ubatch split_seq(size_t n_ubatch);
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
}; };
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
......
...@@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { ...@@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, { "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, { "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
...@@ -202,19 +203,20 @@ int32_t llm_chat_apply_template( ...@@ -202,19 +203,20 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|im_start|>assistant\n"; ss << "<|im_start|>assistant\n";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) { } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
// Official mistral 'v7' template // Official mistral 'v7' template
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
std::string content(message->content); std::string content(message->content);
if (role == "system") { if (role == "system") {
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]"; ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
} else if (role == "user") { } else if (role == "user") {
ss << "[INST] " << content << "[/INST]"; ss << "[INST]" << trailing_space << content << "[/INST]";
} } else {
else { ss << trailing_space << content << "</s>";
ss << " " << content << "</s>";
} }
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
...@@ -447,8 +449,16 @@ int32_t llm_chat_apply_template( ...@@ -447,8 +449,16 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|assistant|>"; ss << "<|assistant|>";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
ss << "[gMASK]" << "<sop>"; ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content; ss << "<|" << role << "|>" << "\n" << message->content;
......
...@@ -14,6 +14,7 @@ enum llm_chat_template { ...@@ -14,6 +14,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MISTRAL_V3, LLM_CHAT_TEMPLATE_MISTRAL_V3,
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
LLM_CHAT_TEMPLATE_MISTRAL_V7, LLM_CHAT_TEMPLATE_MISTRAL_V7,
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
LLM_CHAT_TEMPLATE_PHI_3, LLM_CHAT_TEMPLATE_PHI_3,
LLM_CHAT_TEMPLATE_PHI_4, LLM_CHAT_TEMPLATE_PHI_4,
LLM_CHAT_TEMPLATE_FALCON_3, LLM_CHAT_TEMPLATE_FALCON_3,
......
This diff is collapsed.
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "llama-kv-cache.h" #include "llama-kv-cache.h"
#include "ggml-cpp.h" #include "ggml-cpp.h"
#include "ggml-opt.h"
#include <map> #include <map>
#include <vector> #include <vector>
...@@ -28,7 +29,12 @@ struct llama_context { ...@@ -28,7 +29,12 @@ struct llama_context {
void synchronize(); void synchronize();
const llama_model & get_model() const; const llama_model & get_model() const;
const llama_cparams & get_cparams() const;
ggml_backend_sched_t get_sched() const;
ggml_context * get_ctx_compute() const;
uint32_t n_ctx() const; uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const; uint32_t n_ctx_per_seq() const;
...@@ -130,6 +136,32 @@ struct llama_context { ...@@ -130,6 +136,32 @@ struct llama_context {
llama_perf_context_data perf_get_data() const; llama_perf_context_data perf_get_data() const;
void perf_reset(); void perf_reset();
//
// training
//
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
void opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
void opt_epoch_iter(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,
int64_t ndata_in_loop,
int64_t t_loop_start);
private: private:
// //
// output // output
...@@ -139,50 +171,30 @@ private: ...@@ -139,50 +171,30 @@ private:
// Returns max number of outputs for which space was reserved. // Returns max number of outputs for which space was reserved.
int32_t output_reserve(int32_t n_outputs); int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
void output_reorder();
// //
// graph // graph
// //
public:
int32_t graph_max_nodes() const; int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph // zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init(); ggml_cgraph * graph_init();
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
private:
llm_graph_result_ptr graph_build( llm_graph_result_ptr graph_build(
ggml_context * ctx, ggml_context * ctx,
ggml_cgraph * gf, ggml_cgraph * gf,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
llm_graph_type gtype); llm_graph_type gtype);
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
llm_graph_cb graph_get_cb() const; llm_graph_cb graph_get_cb() const;
// used by kv_self_update()
ggml_tensor * build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf,
const std::vector<struct llama_kv_defrag_move> & moves) const;
// TODO: read/write lora adapters and cvec // TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io); size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io); size_t state_read_data (llama_io_read_i & io);
...@@ -199,14 +211,10 @@ private: ...@@ -199,14 +211,10 @@ private:
llama_cparams cparams; llama_cparams cparams;
llama_adapter_cvec cvec; llama_adapter_cvec cvec;
llama_adapter_loras loras; llama_adapter_loras loras;
llama_sbatch sbatch;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_kv_cache_unified> kv_self; std::unique_ptr<llama_memory_i> memory;
// TODO: remove
bool logits_all = false;
// decode output (2-dimensional array: [n_outputs][n_vocab]) // decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits size_t logits_size = 0; // capacity (of floats) for logits
...@@ -233,6 +241,9 @@ private: ...@@ -233,6 +241,9 @@ private:
ggml_context_ptr ctx_compute; ggml_context_ptr ctx_compute;
// training
ggml_opt_context_t opt_ctx = nullptr;
ggml_threadpool_t threadpool = nullptr; ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr; ggml_threadpool_t threadpool_batch = nullptr;
......
...@@ -29,8 +29,9 @@ struct llama_cparams { ...@@ -29,8 +29,9 @@ struct llama_cparams {
bool offload_kqv; bool offload_kqv;
bool flash_attn; bool flash_attn;
bool no_perf; bool no_perf;
bool cross_attn;
bool warmup; bool warmup;
bool op_offload;
bool cross_attn;
enum llama_pooling_type pooling_type; enum llama_pooling_type pooling_type;
......
...@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { ...@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) { for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self->head; data[i] = kv_self->s_copy(i);
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
kv_cell.src = cell_id;
}
data[i] = kv_cell.src;
// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
}
} }
} }
} }
...@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { ...@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
// clear unused states // clear unused states
for (int i = 0; i < n_kv; ++i) { for (int i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self->head; data[i] = kv_self->s_mask(i);
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
data[i] = (float) (kv_cell.src >= 0);
// only clear once
if (kv_cell.src < 0) {
kv_cell.src = cell_id;
}
} }
} }
} }
...@@ -816,7 +788,7 @@ ggml_tensor * llm_graph_context::build_ffn( ...@@ -816,7 +788,7 @@ ggml_tensor * llm_graph_context::build_ffn(
} break; } break;
} }
if (type_gate == LLM_FFN_PAR) { if (gate && type_gate == LLM_FFN_PAR) {
cur = ggml_mul(ctx0, cur, tmp); cur = ggml_mul(ctx0, cur, tmp);
cb(cur, "ffn_gate_par", il); cb(cur, "ffn_gate_par", il);
} }
...@@ -1005,6 +977,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { ...@@ -1005,6 +977,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
//cb(inp->tokens, "inp_tokens", -1); //cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens); ggml_set_input(inp->tokens);
res->t_tokens = inp->tokens;
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
...@@ -1111,7 +1084,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { ...@@ -1111,7 +1084,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
} }
ggml_tensor * llm_graph_context::build_inp_s_copy() const { ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self); auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
...@@ -1128,7 +1101,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { ...@@ -1128,7 +1101,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
} }
ggml_tensor * llm_graph_context::build_inp_s_mask() const { ggml_tensor * llm_graph_context::build_inp_s_mask() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self); auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
...@@ -1261,8 +1234,19 @@ ggml_tensor * llm_graph_context::build_attn_mha( ...@@ -1261,8 +1234,19 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
if (v_mla) { if (v_mla) {
#if 0
// v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
// However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
cur = ggml_mul_mat(ctx0, v_mla, cur); cur = ggml_mul_mat(ctx0, v_mla, cur);
#else
// It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
// The permutations are noops and only change how the tensor data is interpreted.
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_mul_mat(ctx0, v_mla, cur);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
#endif
} }
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
...@@ -1442,8 +1426,6 @@ ggml_tensor * llm_graph_context::build_attn( ...@@ -1442,8 +1426,6 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache // store to KV cache
{ {
GGML_ASSERT(!kv_self->recurrent);
const auto kv_head = kv_self->head; const auto kv_head = kv_self->head;
GGML_ASSERT(kv_self->size == n_ctx); GGML_ASSERT(kv_self->size == n_ctx);
...@@ -1612,7 +1594,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state( ...@@ -1612,7 +1594,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
ggml_tensor * state_mask, ggml_tensor * state_mask,
int32_t n_state, int32_t n_state,
int32_t n_seqs) const { int32_t n_seqs) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_kv = kv_self->n; const auto n_kv = kv_self->n;
const auto kv_head = kv_self->head; const auto kv_head = kv_self->head;
...@@ -1644,7 +1626,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ...@@ -1644,7 +1626,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * state_mask, ggml_tensor * state_mask,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
int il) const { int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto token_shift_count = hparams.token_shift_count; const auto token_shift_count = hparams.token_shift_count;
...@@ -1665,7 +1647,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ...@@ -1665,7 +1647,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift, ggml_tensor * token_shift,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
int il) const { int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto token_shift_count = hparams.token_shift_count; const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd; const auto n_embd = hparams.n_embd;
......
...@@ -19,6 +19,7 @@ struct llama_cparams; ...@@ -19,6 +19,7 @@ struct llama_cparams;
class llama_memory_i; class llama_memory_i;
class llama_kv_cache_unified; class llama_kv_cache_unified;
class llama_kv_cache_recurrent;
// certain models (typically multi-modal) can produce different types of graphs // certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type { enum llm_graph_type {
...@@ -187,26 +188,26 @@ public: ...@@ -187,26 +188,26 @@ public:
class llm_graph_input_s_copy : public llm_graph_input_i { class llm_graph_input_s_copy : public llm_graph_input_i {
public: public:
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default; virtual ~llm_graph_input_s_copy() = default;
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size] ggml_tensor * s_copy; // I32 [kv_size]
const llama_kv_cache_unified * kv_self; const llama_kv_cache_recurrent * kv_self;
}; };
class llm_graph_input_s_mask : public llm_graph_input_i { class llm_graph_input_s_mask : public llm_graph_input_i {
public: public:
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default; virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv] ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_unified * kv_self; const llama_kv_cache_recurrent * kv_self;
}; };
class llm_graph_input_cross_embd : public llm_graph_input_i { class llm_graph_input_cross_embd : public llm_graph_input_i {
...@@ -308,6 +309,7 @@ class llm_graph_result_i { ...@@ -308,6 +309,7 @@ class llm_graph_result_i {
public: public:
virtual ~llm_graph_result_i() = default; virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_tokens() = 0;
virtual ggml_tensor * get_logits() = 0; virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0; virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0; virtual ggml_tensor * get_embd_pooled() = 0;
...@@ -322,6 +324,7 @@ class llm_graph_result : public llm_graph_result_i { ...@@ -322,6 +324,7 @@ class llm_graph_result : public llm_graph_result_i {
public: public:
virtual ~llm_graph_result() = default; virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() override { return t_tokens; }
ggml_tensor * get_logits() override { return t_logits; } ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; } ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
...@@ -338,6 +341,7 @@ public: ...@@ -338,6 +341,7 @@ public:
} }
// important graph nodes // important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr; ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr; ggml_tensor * t_embd_pooled = nullptr;
...@@ -361,8 +365,8 @@ struct llm_graph_params { ...@@ -361,8 +365,8 @@ struct llm_graph_params {
const llama_cparams & cparams; const llama_cparams & cparams;
const llama_ubatch & ubatch; const llama_ubatch & ubatch;
ggml_backend_sched * sched; ggml_backend_sched_t sched;
ggml_backend * backend_cpu; ggml_backend_t backend_cpu;
const llama_adapter_cvec * cvec; const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras; const llama_adapter_loras * loras;
...@@ -413,9 +417,9 @@ struct llm_graph_context { ...@@ -413,9 +417,9 @@ struct llm_graph_context {
ggml_context * ctx0 = nullptr; ggml_context * ctx0 = nullptr;
ggml_backend_sched * sched; ggml_backend_sched_t sched;
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec; const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras; const llama_adapter_loras * loras;
......
This diff is collapsed.
...@@ -2,32 +2,72 @@ ...@@ -2,32 +2,72 @@
#include "llama.h" #include "llama.h"
#include "llama-io.h" #include "llama-io.h"
#include "llama-graph.h"
#include "llama-memory.h" #include "llama-memory.h"
#include "ggml-cpp.h" #include "ggml-cpp.h"
#include <functional>
#include <set> #include <set>
#include <vector> #include <vector>
struct llama_cparams; struct llama_cparams;
struct llama_hparams; struct llama_hparams;
struct llama_ubatch; struct llama_ubatch;
struct llama_sbatch;
struct llama_model;
struct llama_context;
struct llama_kv_cache : public llama_memory_i { struct llama_kv_cache : public llama_memory_i {
using llama_memory_i::llama_memory_i; virtual ~llama_kv_cache() = default;
virtual void restore() = 0; // call if batch processing fails - restores the cache state // call if batch processing fails - restores the cache state
virtual void commit() = 0; // call after successful batch processing - clears any pending state virtual void restore() = 0;
virtual int32_t get_n_tokens() const = 0; // call after successful batch processing - clears any pending state
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache virtual void commit() = 0;
virtual bool get_can_shift() const = 0; // process any pending defrag/shift/etc. operations
// optionally call once before processing a new batch
virtual bool update(llama_context & lctx) = 0;
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
virtual void defrag_sched(float thold) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual void set_full() = 0;
//
// batch processing
//
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
// different KV caches require different batch splitting strategies
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
// find an empty slot of size "n_tokens" in the cache
virtual bool find_slot(const llama_ubatch & batch) = 0;
// getters
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual llama_pos get_pos_max() const = 0;
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); } bool get_can_edit() const override { return get_can_shift(); }
//
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
}; };
//
// llama_kv_cache_guard
//
struct llama_kv_cache_guard { struct llama_kv_cache_guard {
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
...@@ -42,7 +82,7 @@ struct llama_kv_cache_guard { ...@@ -42,7 +82,7 @@ struct llama_kv_cache_guard {
private: private:
llama_kv_cache * kv; llama_kv_cache * kv;
}; };
// block of KV slots to move when defragging // block of KV slots to move when defragging
struct llama_kv_defrag_move { struct llama_kv_defrag_move {
uint32_t src; uint32_t src;
...@@ -50,65 +90,50 @@ struct llama_kv_defrag_move { ...@@ -50,65 +90,50 @@ struct llama_kv_defrag_move {
uint32_t len; uint32_t len;
}; };
struct llama_kv_cell { //
llama_pos pos = -1; // llama_kv_cache_unified
llama_pos delta = 0; //
int32_t src = -1; // used by recurrent state models to copy states
int32_t tail = -1;
std::set<llama_seq_id> seq_id; // TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
struct kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
bool has_seq_id(const llama_seq_id & id) const { std::set<llama_seq_id> seq_id;
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const { bool has_seq_id(const llama_seq_id & id) const {
return seq_id.empty(); return seq_id.find(id) != seq_id.end();
} }
bool is_same_seq(const llama_kv_cell & other) const { bool is_empty() const {
return seq_id == other.seq_id; return seq_id.empty();
} }
};
// ring-buffer of cached KV data bool is_same_seq(const kv_cell & other) const {
// TODO: pimpl return seq_id == other.seq_id;
// TODO: add notion of max sequences }
class llama_kv_cache_unified : public llama_kv_cache {
public:
// can be used to query data from the model if needed
struct callbacks {
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
}; };
llama_kv_cache_unified( static uint32_t get_padding(const llama_cparams & cparams);
const llama_hparams & hparams,
callbacks cbs);
virtual ~llama_kv_cache_unified() = default; llama_kv_cache_unified(
const llama_model & model,
// TODO: become constructor
bool init(
const llama_model & model, // TODO: do not reference the model
const llama_cparams & cparams,
ggml_type type_k, ggml_type type_k,
ggml_type type_v, ggml_type type_v,
bool v_trans,
bool offload,
uint32_t kv_size, uint32_t kv_size,
bool offload); uint32_t padding);
int32_t get_n_tokens() const override; ~llama_kv_cache_unified() = default;
int32_t get_used_cells() const override;
size_t total_size() const; //
// llama_memory_i
// TODO: better data structures to reduce the cost of this operation //
llama_pos pos_max() const;
void clear() override; void clear() override;
void defrag() override;
virtual void restore() override;
virtual void commit() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
...@@ -118,25 +143,76 @@ public: ...@@ -118,25 +143,76 @@ public:
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
bool get_can_shift() const override; //
// llama_kv_cache
//
void restore() override;
void commit() override;
bool update(llama_context & ctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
// find an empty slot of size "n_tokens" in the cache
// updates the cache head // updates the cache head
// Note: On success, it's important that cache.head points // Note: On success, it's important that cache.head points
// to the first cell of the slot. // to the first cell of the slot.
bool find_slot(const llama_ubatch & batch); bool find_slot(const llama_ubatch & batch) override;
// TODO: maybe not needed int32_t get_n_tokens() const override;
uint32_t get_padding(const llama_cparams & cparams) const; int32_t get_used_cells() const override;
// find how many cells are currently in use // TODO: better data structures to reduce the cost of this operation
uint32_t cell_max() const; llama_pos get_pos_max() const override;
size_t size_k_bytes() const; bool get_can_shift() const override;
size_t size_v_bytes() const;
// defrag // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
uint32_t n = 0;
std::vector<kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
private:
const llama_model & model;
const llama_hparams & hparams;
bool has_shift = false;
bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
// required padding
uint32_t padding = 1;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// defrag
struct { struct {
std::vector<llama_kv_defrag_move> moves; std::vector<llama_kv_defrag_move> moves;
} defrag_info; } defrag_info;
...@@ -145,7 +221,6 @@ public: ...@@ -145,7 +221,6 @@ public:
bool defrag_prepare(int32_t n_max_nodes); bool defrag_prepare(int32_t n_max_nodes);
// commit/restore cache // commit/restore cache
struct slot_range { struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0; uint32_t c1 = 0;
...@@ -156,25 +231,125 @@ public: ...@@ -156,25 +231,125 @@ public:
std::vector<slot_range> ranges; std::vector<slot_range> ranges;
} pending; } pending;
// state write/load // find how many cells are currently in use
uint32_t cell_max() const;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const; size_t total_size() const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
// members size_t size_k_bytes() const;
size_t size_v_bytes() const;
const llama_hparams & hparams; ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_graph_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_graph_defrag(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf,
const std::vector<llama_kv_defrag_move> & moves) const;
callbacks cbs; void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
bool has_shift = false; bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool do_defrag = false; bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
// TODO: remove this and implement llama_kv_cache_recurrent instead //
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token // llama_kv_cache_recurrent
//
bool v_trans = true; // the value tensor is transposed class llama_kv_cache_recurrent : public llama_kv_cache {
bool can_shift = false; public:
struct kv_cell {
llama_pos pos = -1;
int32_t src = -1; // used to copy states
int32_t tail = -1;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size);
~llama_kv_cache_recurrent() = default;
//
// llama_memory_i
//
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
void restore() override;
void commit() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
bool find_slot(const llama_ubatch & batch) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
bool get_can_shift() const override;
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
int32_t s_copy(int i) const;
float s_mask(int i) const;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching // Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it // for a free KV slot. llama_decode_impl also uses it, so it
...@@ -186,18 +361,41 @@ public: ...@@ -186,18 +361,41 @@ public:
// computed before each graph build // computed before each graph build
uint32_t n = 0; uint32_t n = 0;
std::vector<llama_kv_cell> cells; std::vector<kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l; std::vector<ggml_tensor *> v_l;
private: private:
//const llama_model & model;
const llama_hparams & hparams; // model hyper-parameters (shapes of the per-layer tensors)
// commit/restore cache
// TODO: rework for recurrent cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
// (tracked as cell-index ranges so restore() can roll them back)
struct {
std::vector<slot_range> ranges;
} pending;
ggml_type type_k = GGML_TYPE_F16; ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs; std::vector<ggml_backend_buffer_ptr> bufs;
// find how many cells are currently in use
uint32_t cell_max() const;
// total size of all cache buffers, in bytes
size_t total_size() const;
// byte sizes of the K and V tensors, respectively
size_t size_k_bytes() const;
size_t size_v_bytes() const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const; void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
...@@ -205,11 +403,6 @@ private: ...@@ -205,11 +403,6 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count); bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
}; };
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
//public:
// using llama_kv_cache_unified::llama_kv_cache_unified;
//};
// //
// kv cache view // kv cache view
......
...@@ -2,12 +2,22 @@ ...@@ -2,12 +2,22 @@
#include "llama.h" #include "llama.h"
// creation-time parameters for a memory module
struct llama_memory_params {
// kv cache
ggml_type type_k; // element type of the K tensors
ggml_type type_v; // element type of the V tensors
// parameters for other types of memory
// ...
};
// general concept of LLM memory // general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types // the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i { class llama_memory_i {
public: public:
virtual ~llama_memory_i() = default;
virtual void clear() = 0; virtual void clear() = 0;
virtual void defrag() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment