Commit b2b270ad authored by Devon Rifkin's avatar Devon Rifkin
Browse files

Merge branch 'main' into drifkin/array-head-count-simple

parents 20c5fd39 2bb69b40
...@@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) ...@@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor, _ := context.FromFloatSlice(test.in, test.inShape...) tensor := context.FromFloatSlice(test.in, test.inShape...)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
out, _, mask := cache.Get(context) out, _, mask := cache.Get(context)
...@@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) { ...@@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) {
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet // with window size 4, nothing has slid out of the window yet
...@@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) { ...@@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) {
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows // only the latest position has overlapping windows
...@@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { ...@@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...) return c.Empty(dtype, shape...)
} }
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
t := c.Empty(ml.DTypeF32, shape...).(*testTensor) t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s) copy(t.data, s)
return t, nil return t
} }
func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor {
f := make([]float32, len(s)) f := make([]float32, len(s))
for i := range f { for i := range f {
f[i] = float32(s[i]) f[i] = float32(s[i])
} }
out, _ := c.FromFloatSlice(f, shape...) out := c.FromFloatSlice(f, shape...)
out.(*testTensor).dtype = ml.DTypeI32 out.(*testTensor).dtype = ml.DTypeI32
return out, nil return out
} }
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
...@@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso ...@@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
s = append(s, i) s = append(s, i)
} }
out, _ := c.FromFloatSlice(s, len(s)) out := c.FromFloatSlice(s, len(s))
out.(*testTensor).dtype = dtype out.(*testTensor).dtype = dtype
return out return out
} }
...@@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } ...@@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {} func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() error { return nil } func (c *testContext) Reserve() {}
func (c *testContext) MaxGraphNodes() int { func (c *testContext) MaxGraphNodes() int {
return 10 return 10
......
int LLAMA_BUILD_NUMBER = 0; int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "e1e8e0991ffd9e99a445c6812bb519d5bac9f4b5"; char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618";
char const *LLAMA_COMPILER = ""; char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = ""; char const *LLAMA_BUILD_TARGET = "";
...@@ -10,11 +10,11 @@ include common/stb_image.* ...@@ -10,11 +10,11 @@ include common/stb_image.*
include include/ include include/
include include/llama.* include include/llama.*
include include/llama-*.* include include/llama-*.*
include examples/ include tools/
include examples/llava/ include tools/mtmd/
include examples/llava/clip.* include tools/mtmd/clip.*
include examples/llava/clip-impl.* include tools/mtmd/clip-impl.*
include examples/llava/llava.* include tools/mtmd/llava.*
include src/ include src/
include src/llama.* include src/llama.*
include src/llama-*.* include src/llama-*.*
......
...@@ -1096,7 +1096,6 @@ struct llama_context_params common_context_params_to_llama(const common_params & ...@@ -1096,7 +1096,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads; params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding; cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type; cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base; cparams.rope_freq_base = params.rope_freq_base;
...@@ -1114,6 +1113,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & ...@@ -1114,6 +1113,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.offload_kqv = !params.no_kv_offload; cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn; cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf; cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
if (params.reranking) { if (params.reranking) {
cparams.embeddings = true; cparams.embeddings = true;
...@@ -1565,3 +1565,20 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c ...@@ -1565,3 +1565,20 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result; return result;
} }
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
const int64_t ne_datapoint = llama_n_ctx(ctx);
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
ggml_opt_dataset_t result = ggml_opt_dataset_init(
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
for (int64_t idata = 0; idata < ndata; ++idata) {
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
}
return result;
}
...@@ -66,7 +66,6 @@ enum llama_example { ...@@ -66,7 +66,6 @@ enum llama_example {
LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_COMMON,
LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SPECULATIVE,
LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MAIN,
LLAMA_EXAMPLE_INFILL,
LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_EMBEDDING,
LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_PERPLEXITY,
LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_RETRIEVAL,
...@@ -96,6 +95,7 @@ enum common_sampler_type { ...@@ -96,6 +95,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
}; };
// dimensionality reduction methods, used by cvector-generator // dimensionality reduction methods, used by cvector-generator
...@@ -161,6 +161,7 @@ struct common_params_sampling { ...@@ -161,6 +161,7 @@ struct common_params_sampling {
std::vector<enum common_sampler_type> samplers = { std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES, COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY, COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TYPICAL_P, COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P, COMMON_SAMPLER_TYPE_TOP_P,
...@@ -323,7 +324,6 @@ struct common_params { ...@@ -323,7 +324,6 @@ struct common_params {
bool ctx_shift = true; // context shift on inifinite text generation bool ctx_shift = true; // context shift on inifinite text generation
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation bool verbose_prompt = false; // print prompt tokens before generation
...@@ -332,6 +332,7 @@ struct common_params { ...@@ -332,6 +332,7 @@ struct common_params {
bool no_kv_offload = false; // disable KV offloading bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool single_turn = false; // single turn chat conversation bool single_turn = false; // single turn chat conversation
...@@ -340,7 +341,7 @@ struct common_params { ...@@ -340,7 +341,7 @@ struct common_params {
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see examples/llava) // multimodal models (see tools/mtmd)
struct common_params_model mmproj; struct common_params_model mmproj;
bool mmproj_use_gpu = true; // use GPU for multimodal model bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model bool no_mmproj = false; // explicitly disable multimodal model
...@@ -409,13 +410,14 @@ struct common_params { ...@@ -409,13 +410,14 @@ struct common_params {
bool process_output = false; // collect data for the output tensor bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity bool compute_ppl = true; // whether to compute perplexity
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
// cvector-generator params // cvector-generator params
int n_pca_batch = 100; int n_pca_batch = 100;
int n_pca_iterations = 1000; int n_pca_iterations = 1000;
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
bool spm_infill = false; // suffix/prefix/middle pattern for infill bool spm_infill = false; // suffix/prefix/middle pattern for infill
...@@ -664,3 +666,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count"; ...@@ -664,3 +666,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
} }
//
// training utils
//
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
#include "sampling.h" #include "sampling.h"
#include "common.h" #include "common.h"
#include "log.h"
#include <cmath> #include <cmath>
#include <unordered_map> #include <unordered_map>
...@@ -229,51 +230,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co ...@@ -229,51 +230,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
params.logit_bias.data())); params.logit_bias.data()));
if (params.mirostat == 0) { if (params.mirostat == 0) {
if (params.top_n_sigma >= 0) { for (const auto & cnstr : params.samplers) {
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); switch (cnstr) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp)); case COMMON_SAMPLER_TYPE_DRY:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); {
} else { std::vector<const char *> c_breakers;
for (const auto & cnstr : params.samplers) { c_breakers.reserve(params.dry_sequence_breakers.size());
switch (cnstr) { for (const auto & str : params.dry_sequence_breakers) {
case COMMON_SAMPLER_TYPE_DRY: c_breakers.push_back(str.c_str());
{
std::vector<const char *> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break;
case COMMON_SAMPLER_TYPE_TOP_K: llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); }
break; break;
case COMMON_SAMPLER_TYPE_TOP_P: case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break; break;
case COMMON_SAMPLER_TYPE_MIN_P: case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_XTC: case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
break; break;
case COMMON_SAMPLER_TYPE_TYPICAL_P: case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TEMPERATURE: case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break; break;
case COMMON_SAMPLER_TYPE_INFILL: case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_PENALTIES: case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break; break;
default: case COMMON_SAMPLER_TYPE_INFILL:
GGML_ASSERT(false && "unknown sampler type"); llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
} break;
case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
} }
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
...@@ -475,6 +473,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { ...@@ -475,6 +473,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_XTC: return 'x';
...@@ -490,6 +489,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { ...@@ -490,6 +489,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_XTC: return "xtc";
...@@ -504,6 +504,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect ...@@ -504,6 +504,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "dry", COMMON_SAMPLER_TYPE_DRY }, { "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K }, { "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P }, { "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P }, { "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
...@@ -517,6 +518,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect ...@@ -517,6 +518,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map { std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K }, { "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P }, { "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, { "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
...@@ -533,14 +535,16 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect ...@@ -533,14 +535,16 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
auto sampler = sampler_canonical_name_map.find(name); auto sampler = sampler_canonical_name_map.find(name);
if (sampler != sampler_canonical_name_map.end()) { if (sampler != sampler_canonical_name_map.end()) {
samplers.push_back(sampler->second); samplers.push_back(sampler->second);
} else { continue;
if (allow_alt_names) { }
sampler = sampler_alt_name_map.find(name); if (allow_alt_names) {
if (sampler != sampler_alt_name_map.end()) { sampler = sampler_alt_name_map.find(name);
samplers.push_back(sampler->second); if (sampler != sampler_alt_name_map.end()) {
} samplers.push_back(sampler->second);
continue;
} }
} }
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
} }
return samplers; return samplers;
...@@ -552,6 +556,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri ...@@ -552,6 +556,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
...@@ -566,6 +571,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri ...@@ -566,6 +571,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
const auto sampler = sampler_name_map.find(c); const auto sampler = sampler_name_map.find(c);
if (sampler != sampler_name_map.end()) { if (sampler != sampler_name_map.end()) {
samplers.push_back(sampler->second); samplers.push_back(sampler->second);
} else {
LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
} }
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "ggml.h" #include "ggml.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-opt.h"
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
...@@ -112,6 +113,7 @@ extern "C" { ...@@ -112,6 +113,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
}; };
enum llama_rope_type { enum llama_rope_type {
...@@ -256,7 +258,6 @@ extern "C" { ...@@ -256,7 +258,6 @@ extern "C" {
llama_token * token; llama_token * token;
float * embd; float * embd;
int32_t n_embd;
llama_pos * pos; llama_pos * pos;
int32_t * n_seq_id; int32_t * n_seq_id;
llama_seq_id ** seq_id; llama_seq_id ** seq_id;
...@@ -352,20 +353,18 @@ extern "C" { ...@@ -352,20 +353,18 @@ extern "C" {
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
// TODO: move at the end of the struct
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool cross_attn; // whether to use cross attention
// Abort callback // Abort callback
// if it returns true, execution of llama_decode() will be aborted // if it returns true, execution of llama_decode() will be aborted
// currently works only with CPU execution // currently works only with CPU execution
ggml_abort_callback abort_callback; ggml_abort_callback abort_callback;
void * abort_callback_data; void * abort_callback_data;
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool op_offload; // whether to offload host tensor operations to device
}; };
// model quantization parameters // model quantization parameters
...@@ -447,6 +446,10 @@ extern "C" { ...@@ -447,6 +446,10 @@ extern "C" {
size_t n_paths, size_t n_paths,
struct llama_model_params params); struct llama_model_params params);
LLAMA_API void llama_model_save_to_file(
const struct llama_model * model,
const char * path_model);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
"use llama_model_free instead"); "use llama_model_free instead");
...@@ -461,10 +464,6 @@ extern "C" { ...@@ -461,10 +464,6 @@ extern "C" {
struct llama_context_params params), struct llama_context_params params),
"use llama_init_from_model instead"); "use llama_init_from_model instead");
// TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches.
LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
// Frees all allocated memory // Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx); LLAMA_API void llama_free(struct llama_context * ctx);
...@@ -930,14 +929,19 @@ extern "C" { ...@@ -930,14 +929,19 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init() // Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch); LLAMA_API void llama_batch_free(struct llama_batch batch);
// Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Process a batch of tokens.
// Stores the encoder output internally for later use by the decoder cross-attention layers. // In contrast to llama_decode() - this call does not use KV cache.
// For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
// 0 - success // 0 - success
// < 0 - error. the KV cache state is restored to the state before this call // < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_encode( LLAMA_API int32_t llama_encode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch); struct llama_batch batch);
// Process a batch of tokens.
// Requires KV cache.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning. // Positive return values does not mean a fatal error, but rather a warning.
// 0 - success // 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
...@@ -1434,6 +1438,37 @@ extern "C" { ...@@ -1434,6 +1438,37 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
//
// training
//
// function that returns whether or not a given tensor contains trainable parameters
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
// always returns true
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
struct llama_opt_params {
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
};
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
LLAMA_API void llama_opt_epoch(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
...@@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ ...@@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
std::vector<ggml_backend_buffer_type_t> buft_extra; std::vector<ggml_backend_buffer_type_t> buft_extra;
{ {
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
...@@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ ...@@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
buft = ggml_backend_dev_buffer_type(cpu_dev); buft = ggml_backend_dev_buffer_type(cpu_dev);
break; break;
......
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" }, { LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_MLLAMA, "mllama" },
{ LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_DECI, "deci" }, { LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_FALCON, "falcon" },
...@@ -145,7 +144,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { ...@@ -145,7 +144,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
{ LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
...@@ -275,40 +273,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N ...@@ -275,40 +273,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
}, },
}, },
{
LLM_ARCH_MLLAMA,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
{ LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
{ LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
{ LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
{ LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
{ LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
{ LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
{ LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
},
},
{ {
LLM_ARCH_DECI, LLM_ARCH_DECI,
{ {
...@@ -1737,14 +1701,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { ...@@ -1737,14 +1701,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
// this tensor is loaded for T5, but never used // this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_K_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_O_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_Q_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_V_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_MLP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
enum llm_arch { enum llm_arch {
LLM_ARCH_LLAMA, LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4, LLM_ARCH_LLAMA4,
LLM_ARCH_MLLAMA,
LLM_ARCH_DECI, LLM_ARCH_DECI,
LLM_ARCH_FALCON, LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN, LLM_ARCH_BAICHUAN,
...@@ -149,7 +148,6 @@ enum llm_kv { ...@@ -149,7 +148,6 @@ enum llm_kv {
LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
...@@ -351,14 +349,6 @@ enum llm_tensor { ...@@ -351,14 +349,6 @@ enum llm_tensor {
LLM_TENSOR_CLS, LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT, LLM_TENSOR_CLS_OUT,
LLM_TENSOR_BSKCN_TV, LLM_TENSOR_BSKCN_TV,
LLM_TENSOR_CROSS_ATTN_K_NORM,
LLM_TENSOR_CROSS_ATTN_K_PROJ,
LLM_TENSOR_CROSS_ATTN_O_PROJ,
LLM_TENSOR_CROSS_ATTN_Q_NORM,
LLM_TENSOR_CROSS_ATTN_Q_PROJ,
LLM_TENSOR_CROSS_ATTN_V_PROJ,
LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
LLM_TENSOR_CROSS_ATTN_MLP_GATE,
LLM_TENSOR_CONV1D, LLM_TENSOR_CONV1D,
LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_DW,
LLM_TENSOR_CONVNEXT_NORM, LLM_TENSOR_CONVNEXT_NORM,
......
...@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { ...@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
return ubatch; return ubatch;
} }
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
GGML_ASSERT(batch.n_tokens >= 0); GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch; this->batch = &batch;
this->n_embd = n_embd; this->n_embd = n_embd;
...@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ...@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
for (size_t i = 0; i < n_tokens; ++i) { for (size_t i = 0; i < n_tokens; ++i) {
ids[i] = i; ids[i] = i;
} }
if (simple_split) { if (simple_split) {
seq.resize(1); seq.resize(1);
llama_sbatch_seq & s = seq[0]; llama_sbatch_seq & s = seq[0];
...@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ...@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
s.length = n_tokens; s.length = n_tokens;
return; return;
} }
std::sort(ids.begin(), ids.end(), std::sort(ids.begin(), ids.end(),
[&batch](size_t a, size_t b) { [&batch](size_t a, size_t b) {
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
...@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ...@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
return n_seq_a > n_seq_b; return n_seq_a > n_seq_b;
} }
); );
// init seq // init seq
llama_sbatch_seq * last_seq = nullptr; llama_sbatch_seq * last_seq = nullptr;
...@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ...@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
seq.push_back(new_seq); seq.push_back(new_seq);
last_seq = &seq.back(); last_seq = &seq.back();
} }
// keep shared prompts first at the end, then sort by length descending. // keep shared prompts first at the end, then sort by length descending.
std::sort(seq.begin(), seq.end(), std::sort(seq.begin(), seq.end(),
[](llama_sbatch_seq & a, llama_sbatch_seq & b) { [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
...@@ -316,7 +320,6 @@ struct llama_batch llama_batch_get_one( ...@@ -316,7 +320,6 @@ struct llama_batch llama_batch_get_one(
/*n_tokens =*/ n_tokens, /*n_tokens =*/ n_tokens,
/*tokens =*/ tokens, /*tokens =*/ tokens,
/*embd =*/ nullptr, /*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr, /*pos =*/ nullptr,
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
...@@ -329,7 +332,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ ...@@ -329,7 +332,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*n_tokens =*/ 0, /*n_tokens =*/ 0,
/*tokens =*/ nullptr, /*tokens =*/ nullptr,
/*embd =*/ nullptr, /*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr, /*pos =*/ nullptr,
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
...@@ -338,7 +340,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ ...@@ -338,7 +340,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
if (embd) { if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
batch.n_embd = embd;
} else { } else {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
} }
......
...@@ -70,7 +70,8 @@ struct llama_sbatch { ...@@ -70,7 +70,8 @@ struct llama_sbatch {
// sequence-wise split // sequence-wise split
llama_ubatch split_seq(size_t n_ubatch); llama_ubatch split_seq(size_t n_ubatch);
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
}; };
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
......
...@@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { ...@@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, { "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, { "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
...@@ -202,19 +203,20 @@ int32_t llm_chat_apply_template( ...@@ -202,19 +203,20 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|im_start|>assistant\n"; ss << "<|im_start|>assistant\n";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) { } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
// Official mistral 'v7' template // Official mistral 'v7' template
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
std::string content(message->content); std::string content(message->content);
if (role == "system") { if (role == "system") {
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]"; ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
} else if (role == "user") { } else if (role == "user") {
ss << "[INST] " << content << "[/INST]"; ss << "[INST]" << trailing_space << content << "[/INST]";
} } else {
else { ss << trailing_space << content << "</s>";
ss << " " << content << "</s>";
} }
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
...@@ -447,8 +449,16 @@ int32_t llm_chat_apply_template( ...@@ -447,8 +449,16 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|assistant|>"; ss << "<|assistant|>";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
ss << "[gMASK]" << "<sop>"; ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content; ss << "<|" << role << "|>" << "\n" << message->content;
......
...@@ -14,6 +14,7 @@ enum llm_chat_template { ...@@ -14,6 +14,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MISTRAL_V3, LLM_CHAT_TEMPLATE_MISTRAL_V3,
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
LLM_CHAT_TEMPLATE_MISTRAL_V7, LLM_CHAT_TEMPLATE_MISTRAL_V7,
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
LLM_CHAT_TEMPLATE_PHI_3, LLM_CHAT_TEMPLATE_PHI_3,
LLM_CHAT_TEMPLATE_PHI_4, LLM_CHAT_TEMPLATE_PHI_4,
LLM_CHAT_TEMPLATE_FALCON_3, LLM_CHAT_TEMPLATE_FALCON_3,
......
...@@ -6,11 +6,9 @@ ...@@ -6,11 +6,9 @@
#include "llama-model.h" #include "llama-model.h"
#include "llama-kv-cache.h" #include "llama-kv-cache.h"
#include <cassert>
#include <cstring> #include <cstring>
#include <stdexcept> #include <stdexcept>
#include <cinttypes> #include <cinttypes>
#include <cmath>
// //
// llama_context // llama_context
...@@ -95,6 +93,7 @@ llama_context::llama_context( ...@@ -95,6 +93,7 @@ llama_context::llama_context(
} }
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.op_offload = params.op_offload;
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
...@@ -118,8 +117,6 @@ llama_context::llama_context( ...@@ -118,8 +117,6 @@ llama_context::llama_context(
__func__, n_ctx_per_seq, hparams.n_ctx_train); __func__, n_ctx_per_seq, hparams.n_ctx_train);
} }
logits_all = params.logits_all;
if (!hparams.vocab_only) { if (!hparams.vocab_only) {
// GPU backends // GPU backends
for (auto * dev : model.devices) { for (auto * dev : model.devices) {
...@@ -177,44 +174,13 @@ llama_context::llama_context( ...@@ -177,44 +174,13 @@ llama_context::llama_context(
} }
// init the memory module // init the memory module
// TODO: for now, always create a unified KV cache
if (!hparams.vocab_only) { if (!hparams.vocab_only) {
kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory())); llama_memory_params params_mem = {
/*.type_k =*/ params.type_k,
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); /*.type_v =*/ params.type_v,
};
cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
if (llama_model_is_recurrent(&model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
throw std::runtime_error("failed to initialize self-attention cache");
}
{
const size_t memory_size_k = kv_self->size_k_bytes();
const size_t memory_size_v = kv_self->size_v_bytes();
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, memory.reset(model.create_memory(params_mem, cparams));
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
} }
// init backends // init backends
...@@ -278,7 +244,7 @@ llama_context::llama_context( ...@@ -278,7 +244,7 @@ llama_context::llama_context(
} }
} }
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel)); sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
if (pipeline_parallel) { if (pipeline_parallel) {
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
...@@ -286,7 +252,7 @@ llama_context::llama_context( ...@@ -286,7 +252,7 @@ llama_context::llama_context(
} }
// reserve worst-case graph // reserve worst-case graph
if (!hparams.vocab_only) { if (!hparams.vocab_only && memory) {
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
...@@ -305,7 +271,9 @@ llama_context::llama_context( ...@@ -305,7 +271,9 @@ llama_context::llama_context(
int n_nodes_tg = -1; int n_nodes_tg = -1;
// simulate full KV cache // simulate full KV cache
kv_self->n = kv_self->size; llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->set_full();
cross.v_embd.clear(); cross.v_embd.clear();
...@@ -391,7 +359,9 @@ llama_context::llama_context( ...@@ -391,7 +359,9 @@ llama_context::llama_context(
} }
} }
llama_context::~llama_context() = default; llama_context::~llama_context() {
ggml_opt_free(opt_ctx);
}
void llama_context::synchronize() { void llama_context::synchronize() {
ggml_backend_sched_synchronize(sched.get()); ggml_backend_sched_synchronize(sched.get());
...@@ -427,6 +397,18 @@ const llama_model & llama_context::get_model() const { ...@@ -427,6 +397,18 @@ const llama_model & llama_context::get_model() const {
return model; return model;
} }
const llama_cparams & llama_context::get_cparams() const {
return cparams;
}
ggml_backend_sched_t llama_context::get_sched() const {
return sched.get();
}
ggml_context * llama_context::get_ctx_compute() const {
return ctx_compute.get();
}
uint32_t llama_context::n_ctx() const { uint32_t llama_context::n_ctx() const {
return cparams.n_ctx; return cparams.n_ctx;
} }
...@@ -456,318 +438,44 @@ uint32_t llama_context::n_threads_batch() const { ...@@ -456,318 +438,44 @@ uint32_t llama_context::n_threads_batch() const {
} }
llama_kv_cache * llama_context::get_kv_self() { llama_kv_cache * llama_context::get_kv_self() {
return kv_self.get(); llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
return kv_self;
} }
const llama_kv_cache * llama_context::get_kv_self() const { const llama_kv_cache * llama_context::get_kv_self() const {
return kv_self.get(); llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
} return kv_self;
ggml_tensor * llama_context::build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const {
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
const auto & hparams = model.hparams;
const auto & n_rot = hparams.n_rot;
const auto & rope_type = hparams.rope_type;
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
ggml_tensor * tmp;
if (ggml_is_quantized(cur->type)) {
// dequantize to f32 -> RoPE -> quantize back
tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
tmp = ggml_rope_ext(ctx0, tmp,
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
tmp = ggml_cpy(ctx0, tmp, cur);
} else {
// we rotate only the first n_rot dimensions
tmp = ggml_rope_ext_inplace(ctx0, cur,
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
}
return tmp;
}
class llm_graph_input_k_shift : public llm_graph_input_i {
public:
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_k_shift() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * k_shift; // I32 [kv_size]
const llama_kv_cache_unified * kv_self;
};
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
if (k_shift) {
assert(ggml_backend_buffer_is_host(k_shift->buffer));
int32_t * data = (int32_t *) k_shift->data;
for (uint32_t i = 0; i < kv_self->size; ++i) {
data[i] = kv_self->cells[i].delta;
}
}
}
llm_graph_result_ptr llama_context::build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const {
auto res = std::make_unique<llm_graph_result>();
const auto & hparams = model.hparams;
const auto & n_layer = hparams.n_layer;
const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
//GGML_ASSERT(kv_self->size == n_ctx);
auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
ggml_set_input(inp->k_shift);
for (uint32_t il = 0; il < n_layer; ++il) {
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const bool is_swa = hparams.is_swa(il);
// note: the swa rope params could become part of the cparams in the future
// if we decide to make them configurable, like the non-sliding ones
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
ggml_tensor * k =
ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_head_k, n_head_kv, kv_self->size,
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
0);
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
ggml_build_forward_expand(gf, cur);
}
res->add_input(std::move(inp));
return res;
}
llm_graph_result_ptr llama_context::build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf,
const std::vector<struct llama_kv_defrag_move> & moves) const {
auto res = std::make_unique<llm_graph_result>();
const auto & hparams = model.hparams;
#if 0
// CPU defrag
//
// TODO: optimizations are possible:
// - multiple threads
// - avoid copying to the host memory when already there
//
// likely not worth the effort, as we have ggml_graph based defrag
//
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const uint32_t kv_size = size;
std::vector<uint8_t> buf_k;
std::vector<uint8_t> buf_v;
for (uint32_t il = 0; il < n_layer; ++il) {
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
const size_t v_size_el = ggml_type_size(v_l[il]->type);
const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
buf_k.resize(k_size);
buf_v.resize(v_size);
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
// batch move [i, i+nm) to [id, id+nm)
// note: cells can move only to a lower index
for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t id = ids[i];
if (i == id || id == n_kv) {
continue;
}
uint32_t nm = 1;
while (i + nm < n_kv && ids[i + nm] == id + nm) {
nm++;
}
// move keys
{
const int64_t os = i*k_size_row;
const int64_t od = id*k_size_row;
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
}
// move values (note: they are transposed)
{
const int64_t os = i;
const int64_t od = id;
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
}
}
i += nm - 1;
}
ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
}
#else
for (const auto & move : moves) {
for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
n_embd_k_gqa, move.len,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*move.src));
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
n_embd_k_gqa, move.len,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*move.dst));
ggml_tensor * view_v_src;
ggml_tensor * view_v_dst;
if (cparams.flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
n_embd_v_gqa, move.len,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*move.src));
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
n_embd_v_gqa, move.len,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*move.dst));
} else {
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
move.len, n_embd_v_gqa,
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self->v_l[il]->type, move.src));
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
move.len, n_embd_v_gqa,
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self->v_l[il]->type, move.dst));
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
}
}
#endif
return res;
} }
void llama_context::kv_self_update() { void llama_context::kv_self_update() {
auto & kv = kv_self; bool need_reserve = false;
if (kv->has_shift) {
if (!kv->get_can_shift()) {
GGML_ABORT("The current context does not support K-shift");
}
LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
// apply K-shift if needed
if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
ggml_backend_sched_reset(sched.get());
auto * gf = graph_init();
auto res = build_kv_self_shift(ctx_compute.get(), gf);
ggml_backend_sched_alloc_graph(sched.get(), gf); llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
res->set_inputs(nullptr); need_reserve = kv_self->update(*this);
graph_compute(gf, false); // reserve a worst case graph if needed
} if (need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
{ // build worst-case graph
kv->has_shift = false; uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
for (uint32_t i = 0; i < kv->size; ++i) { // simulate full KV cache
kv->cells[i].delta = 0; kv_self->set_full();
}
}
}
// defragment the KV cache if needed llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
if (kv->do_defrag) { llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
const uint32_t n_max_nodes = graph_max_nodes();
const uint32_t max_moves = (n_max_nodes - 2*model.hparams.n_layer)/(6*model.hparams.n_layer);
if (!kv->defrag_prepare(n_max_nodes)) {
LLAMA_LOG_ERROR("%s: failed to prepare defragmentation\n", __func__);
return;
}
for (std::size_t i = 0; i < kv_self->defrag_info.moves.size(); i += max_moves) { auto * gf = graph_init();
std::vector<struct llama_kv_defrag_move> chunk; graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
auto end = std::min(i + max_moves, kv_self->defrag_info.moves.size());
chunk.assign(kv_self->defrag_info.moves.begin() + i, kv_self->defrag_info.moves.begin() + end);
ggml_backend_sched_reset(sched.get()); // initialize scheduler with the worst-case graph
auto * gf = graph_init(); ggml_backend_sched_reset(sched.get());
auto res = build_kv_self_defrag(ctx_compute.get(), gf, chunk); if (!ggml_backend_sched_reserve(sched.get(), gf)) {
ggml_backend_sched_alloc_graph(sched.get(), gf); LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
res->set_inputs(nullptr);
graph_compute(gf, false);
} }
kv->do_defrag = false;
} }
} }
...@@ -776,9 +484,6 @@ enum llama_pooling_type llama_context::pooling_type() const { ...@@ -776,9 +484,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
} }
float * llama_context::get_logits() { float * llama_context::get_logits() {
// reorder logits for backward compatibility
output_reorder();
return logits; return logits;
} }
...@@ -809,7 +514,7 @@ float * llama_context::get_logits_ith(int32_t i) { ...@@ -809,7 +514,7 @@ float * llama_context::get_logits_ith(int32_t i) {
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
} }
return logits + j*model.hparams.n_vocab; return logits + j*model.vocab.n_tokens();
} catch (const std::exception & err) { } catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG #ifndef NDEBUG
...@@ -821,9 +526,6 @@ float * llama_context::get_logits_ith(int32_t i) { ...@@ -821,9 +526,6 @@ float * llama_context::get_logits_ith(int32_t i) {
} }
float * llama_context::get_embeddings() { float * llama_context::get_embeddings() {
// reorder embeddings for backward compatibility
output_reorder();
return embd; return embd;
} }
...@@ -930,10 +632,6 @@ void llama_context::set_warmup(bool value) { ...@@ -930,10 +632,6 @@ void llama_context::set_warmup(bool value) {
cparams.warmup = value; cparams.warmup = value;
} }
void llama_context::set_cross_attn(bool value) {
cparams.cross_attn = value;
}
void llama_context::set_adapter_lora( void llama_context::set_adapter_lora(
llama_adapter_lora * adapter, llama_adapter_lora * adapter,
float scale) { float scale) {
...@@ -979,8 +677,8 @@ int llama_context::encode(llama_batch & inp_batch) { ...@@ -979,8 +677,8 @@ int llama_context::encode(llama_batch & inp_batch) {
} }
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences // note: during encode, we always pass the full sequence starting from pos = 0
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
const llama_batch & batch = batch_allocr.batch; const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens; const int32_t n_tokens = batch.n_tokens;
...@@ -1005,11 +703,13 @@ int llama_context::encode(llama_batch & inp_batch) { ...@@ -1005,11 +703,13 @@ int llama_context::encode(llama_batch & inp_batch) {
t_compute_start_us = ggml_time_us(); t_compute_start_us = ggml_time_us();
} }
embd_seq.clear();
n_queued_tokens += n_tokens; n_queued_tokens += n_tokens;
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true); llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
const llama_ubatch ubatch = sbatch.split_simple(n_tokens); const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
...@@ -1066,12 +766,12 @@ int llama_context::encode(llama_batch & inp_batch) { ...@@ -1066,12 +766,12 @@ int llama_context::encode(llama_batch & inp_batch) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr); GGML_ASSERT(backend_embd != nullptr);
GGML_ASSERT(embd != nullptr);
switch (cparams.pooling_type) { switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE: case LLAMA_POOLING_TYPE_NONE:
{ {
// extract token embeddings // extract token embeddings
GGML_ASSERT(embd != nullptr);
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
} break; } break;
...@@ -1096,11 +796,18 @@ int llama_context::encode(llama_batch & inp_batch) { ...@@ -1096,11 +796,18 @@ int llama_context::encode(llama_batch & inp_batch) {
} break; } break;
case LLAMA_POOLING_TYPE_RANK: case LLAMA_POOLING_TYPE_RANK:
{ {
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to // extract the rerank score - a single float per sequence
// wait for an encoder model that requires this pooling type in order to test it auto & embd_seq_out = embd_seq;
// https://github.com/ggerganov/llama.cpp/pull/9510
GGML_ABORT("RANK pooling not implemented yet"); for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
} const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(1);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED: case LLAMA_POOLING_TYPE_UNSPECIFIED:
{ {
GGML_ABORT("unknown pooling type"); GGML_ABORT("unknown pooling type");
...@@ -1138,25 +845,33 @@ int llama_context::encode(llama_batch & inp_batch) { ...@@ -1138,25 +845,33 @@ int llama_context::encode(llama_batch & inp_batch) {
} }
int llama_context::decode(llama_batch & inp_batch) { int llama_context::decode(llama_batch & inp_batch) {
if (!memory) {
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
return encode(inp_batch);
}
if (inp_batch.n_tokens == 0) { if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1; return -1;
} }
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
const llama_batch & batch = batch_allocr.batch; const llama_batch & batch = batch_allocr.batch;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const int32_t n_vocab = hparams.n_vocab; const int32_t n_vocab = vocab.n_tokens();
const int64_t n_tokens_all = batch.n_tokens; const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
llama_kv_cache_guard kv_guard(kv_self.get()); llama_kv_cache_guard kv_guard(kv_self);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
...@@ -1190,18 +905,14 @@ int llama_context::decode(llama_batch & inp_batch) { ...@@ -1190,18 +905,14 @@ int llama_context::decode(llama_batch & inp_batch) {
for (uint32_t i = 0; i < n_tokens_all; ++i) { for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0; n_outputs_all += batch.logits[i] != 0;
} }
} else if (logits_all || embd_pooled) { } else if (embd_pooled) {
n_outputs_all = n_tokens_all; n_outputs_all = n_tokens_all;
} else { } else {
// keep last output only // keep last output only
n_outputs_all = 1; n_outputs_all = 1;
} }
const bool logits_all = n_outputs_all == n_tokens_all; llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
sbatch.from_batch(batch, batch.n_embd,
/* simple_split */ !kv_self->recurrent,
/* logits_all */ logits_all);
// reserve output buffer // reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) { if (output_reserve(n_outputs_all) < n_outputs_all) {
...@@ -1215,22 +926,7 @@ int llama_context::decode(llama_batch & inp_batch) { ...@@ -1215,22 +926,7 @@ int llama_context::decode(llama_batch & inp_batch) {
int64_t n_outputs_prev = 0; int64_t n_outputs_prev = 0;
while (sbatch.n_tokens > 0) { while (sbatch.n_tokens > 0) {
llama_ubatch ubatch = llama_ubatch(); llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
const auto & n_ubatch = cparams.n_ubatch;
if (kv_self->recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(cparams.n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = sbatch.split_equal(cparams.n_ubatch);
}
} else {
ubatch = sbatch.split_simple(n_ubatch);
}
// count the outputs in this u_batch // count the outputs in this u_batch
{ {
...@@ -1250,27 +946,15 @@ int llama_context::decode(llama_batch & inp_batch) { ...@@ -1250,27 +946,15 @@ int llama_context::decode(llama_batch & inp_batch) {
} }
// find KV slot // find KV slot
{ if (!kv_self->find_slot(ubatch)) {
kv_self->defrag_sched(-1.0f);
kv_self->update(*this);
if (!kv_self->find_slot(ubatch)) { if (!kv_self->find_slot(ubatch)) {
kv_self->defrag(); LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
kv_self_update(); return 1;
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
return 1;
}
}
if (!kv_self->recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self->get_padding(cparams);
kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
} }
} }
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
ggml_backend_sched_reset(sched.get()); ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
...@@ -1384,43 +1068,68 @@ int llama_context::decode(llama_batch & inp_batch) { ...@@ -1384,43 +1068,68 @@ int llama_context::decode(llama_batch & inp_batch) {
// finalize the batch processing // finalize the batch processing
kv_guard.commit(); kv_guard.commit();
// set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all;
// set output mappings // set output mappings
{ {
bool sorted_output = true; bool sorted_output = true;
GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all); auto & out_ids = sbatch.out_ids;
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
for (int64_t i = 0; i < n_outputs_all; ++i) { for (int64_t i = 0; i < n_outputs_all; ++i) {
int64_t out_id = sbatch.out_ids[i]; int64_t out_id = out_ids[i];
output_ids[out_id] = i; output_ids[out_id] = i;
if (out_id != i) { if (out_id != i) {
sorted_output = false; sorted_output = false;
} }
} }
if (sorted_output) { // make the outputs have the same order they had in the user-provided batch
sbatch.out_ids.clear(); // note: this is mostly relevant for recurrent models atm
if (!sorted_output) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
for (int32_t i = 0; i < n_outputs - 1; ++i) {
int32_t j_min = i;
for (int32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
if (j_min == i) { continue; }
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
}
}
}
std::fill(output_ids.begin(), output_ids.end(), -1);
for (int32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
} }
} }
// set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all;
// wait for the computation to finish (automatically done when obtaining the model output) // wait for the computation to finish (automatically done when obtaining the model output)
//synchronize(); //synchronize();
// decide if we need to defrag the kv cache // decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) { if (cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens) kv_self->defrag_sched(cparams.defrag_thold);
// - count the padding towards the number of used tokens
const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
kv_self->defrag();
}
} }
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
...@@ -1436,11 +1145,12 @@ int llama_context::decode(llama_batch & inp_batch) { ...@@ -1436,11 +1145,12 @@ int llama_context::decode(llama_batch & inp_batch) {
int32_t llama_context::output_reserve(int32_t n_outputs) { int32_t llama_context::output_reserve(int32_t n_outputs) {
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max()); const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
const auto n_batch = cparams.n_batch; const auto n_batch = cparams.n_batch;
const auto n_vocab = hparams.n_vocab; const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd; const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead // TODO: use a per-batch flag for logits presence instead
...@@ -1505,44 +1215,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { ...@@ -1505,44 +1215,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
return n_outputs_max; return n_outputs_max;
} }
void llama_context::output_reorder() {
auto & out_ids = sbatch.out_ids;
if (!out_ids.empty()) {
const uint32_t n_vocab = model.hparams.n_vocab;
const uint32_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
for (int32_t i = 0; i < n_outputs - 1; ++i) {
int32_t j_min = i;
for (int32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
if (j_min == i) { continue; }
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
}
}
}
std::fill(output_ids.begin(), output_ids.end(), -1);
for (int32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
out_ids.clear();
}
}
// //
// graph // graph
// //
...@@ -1579,7 +1251,7 @@ llm_graph_result_ptr llama_context::graph_build( ...@@ -1579,7 +1251,7 @@ llm_graph_result_ptr llama_context::graph_build(
/*.backend_cpu =*/ backend_cpu, /*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec, /*.cvec =*/ &cvec,
/*.loras =*/ &loras, /*.loras =*/ &loras,
/*.memory =*/ kv_self.get(), /*.memory =*/ memory.get(),
/*.cross =*/ &cross, /*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs, /*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(), /*.cb =*/ graph_get_cb(),
...@@ -1983,8 +1655,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { ...@@ -1983,8 +1655,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
{ {
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
output_reorder();
const auto n_outputs = this->n_outputs; const auto n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids; const auto & output_ids = this->output_ids;
...@@ -2015,7 +1685,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { ...@@ -2015,7 +1685,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
{ {
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.hparams.n_vocab); const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
io.write(&logits_size, sizeof(logits_size)); io.write(&logits_size, sizeof(logits_size));
...@@ -2038,6 +1708,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { ...@@ -2038,6 +1708,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
} }
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_write(io); kv_self->state_write(io);
return io.n_bytes(); return io.n_bytes();
...@@ -2121,8 +1793,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { ...@@ -2121,8 +1793,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
} }
} }
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); if (memory) {
kv_self->state_read(io); LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_read(io);
}
return io.n_bytes(); return io.n_bytes();
} }
...@@ -2130,7 +1807,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { ...@@ -2130,7 +1807,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
GGML_UNUSED(seq_id); GGML_UNUSED(seq_id);
kv_self->state_write(io, seq_id); if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_write(io, seq_id);
}
return io.n_bytes(); return io.n_bytes();
} }
...@@ -2138,7 +1819,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s ...@@ -2138,7 +1819,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
GGML_UNUSED(seq_id); GGML_UNUSED(seq_id);
kv_self->state_read(io, seq_id); if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_read(io, seq_id);
}
return io.n_bytes(); return io.n_bytes();
} }
...@@ -2166,6 +1851,218 @@ void llama_context::perf_reset() { ...@@ -2166,6 +1851,218 @@ void llama_context::perf_reset() {
t_p_eval_us = n_p_eval = 0; t_p_eval_us = n_p_eval = 0;
} }
//
// training
//
static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
if (!tensor || tensor->type != GGML_TYPE_F32) {
return;
}
if (!param_filter(tensor, userdata)) {
return;
}
if (strcmp(tensor->name, "token_embd.weight") == 0) {
return; // FIXME
}
if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
return; // FIXME
}
ggml_set_param(tensor);
}
void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
GGML_ASSERT(!opt_ctx);
model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
GGML_ASSERT(n_batch % n_ubatch == 0);
ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
opt_params.opt_period = n_batch / n_ubatch;
opt_params.get_opt_pars = lopt_params.get_opt_pars;
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
opt_ctx = ggml_opt_init(opt_params);
llama_opt_param_filter param_filter = lopt_params.param_filter;
void * param_filter_ud = lopt_params.param_filter_ud;
//llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
llama_set_param(model->type_embd, param_filter, param_filter_ud);
llama_set_param(model->pos_embd, param_filter, param_filter_ud);
llama_set_param(model->tok_norm, param_filter, param_filter_ud);
llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
llama_set_param(model->output_norm, param_filter, param_filter_ud);
llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
llama_set_param(model->output, param_filter, param_filter_ud);
llama_set_param(model->output_b, param_filter, param_filter_ud);
llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
llama_set_param(model->cls, param_filter, param_filter_ud);
llama_set_param(model->cls_b, param_filter, param_filter_ud);
llama_set_param(model->cls_out, param_filter, param_filter_ud);
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
for (struct llama_layer & layer : model->layers) {
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
}
}
}
void llama_context::opt_epoch_iter(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,
int64_t ndata_in_loop,
int64_t t_loop_start) {
GGML_ASSERT(opt_ctx);
const uint32_t n_ctx = llama_model_n_ctx_train(&model);
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->clear();
llama_kv_cache_guard kv_guard(kv_self);
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
batch.n_tokens = n_batch;
for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
batch.pos [pos_batch] = pos_ctx + pos_batch;
batch.n_seq_id[pos_batch] = 1;
batch.seq_id [pos_batch][0] = 0;
batch.logits [pos_batch] = true;
}
const auto n_tokens_all = batch.n_tokens;
n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
embd_seq.clear();
int64_t n_outputs_all = n_tokens_all;
llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
GGML_ABORT("TODO: handle this error");
};
for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
n_outputs = ubatch.n_tokens;
// TODO: not sure if this is needed
if (!kv_self->find_slot(ubatch)) {
kv_self->defrag_sched(-1.0f);
kv_self->update(*this);
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
GGML_ABORT("TODO: handle this error");
}
}
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
struct ggml_context * ctx_compute_opt;
{
const size_t size_gf = ggml_graph_size(gf);
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
struct ggml_init_params params = {
/*.mem_size =*/ size_meta,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ctx_compute_opt = ggml_init(params);
}
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
ggml_opt_alloc(opt_ctx, train);
res->set_inputs(&ubatch);
{
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
GGML_ASSERT(labels->ne[1] == n_ubatch);
ggml_set_zero(labels);
const float onef = 1.0f;
for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
}
}
ggml_opt_eval(opt_ctx, result);
if (callback) {
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
}
ggml_free(ctx_compute_opt);
}
}
kv_guard.commit();
}
void llama_context::opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval) {
const uint32_t n_ctx = this->n_ctx();
const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
const int64_t ndata = ggml_opt_dataset_ndata(dataset);
GGML_ASSERT(idata_split >= 0);
GGML_ASSERT(idata_split <= ndata);
const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
std::vector<llama_token> tokens(n_ctx);
std::vector<llama_token> labels_sparse(n_ctx);
int64_t idata = 0;
int64_t t_loop_start = ggml_time_us();
int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
for (; idata < idata_split; ++idata) {
constexpr bool train = true;
const int64_t idata_in_loop = idata*ubatch_per_ctx;
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
}
t_loop_start = ggml_time_us();
ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
for (; idata < ndata; ++idata) {
constexpr bool train = false;
const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
}
llama_batch_free(batch);
}
// //
// interface implementation // interface implementation
// //
...@@ -2193,14 +2090,13 @@ llama_context_params llama_context_default_params() { ...@@ -2193,14 +2090,13 @@ llama_context_params llama_context_default_params() {
/*.cb_eval_user_data =*/ nullptr, /*.cb_eval_user_data =*/ nullptr,
/*.type_k =*/ GGML_TYPE_F16, /*.type_k =*/ GGML_TYPE_F16,
/*.type_v =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16,
/*.logits_all =*/ false, /*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
/*.embeddings =*/ false, /*.embeddings =*/ false,
/*.offload_kqv =*/ true, /*.offload_kqv =*/ true,
/*.flash_attn =*/ false, /*.flash_attn =*/ false,
/*.no_perf =*/ true, /*.no_perf =*/ true,
/*.cross_attn =*/ false, /*.op_offload =*/ true,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
}; };
return result; return result;
...@@ -2326,10 +2222,6 @@ void llama_set_warmup(llama_context * ctx, bool warmup) { ...@@ -2326,10 +2222,6 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
ctx->set_warmup(warmup); ctx->set_warmup(warmup);
} }
void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
ctx->set_cross_attn(cross_attention);
}
void llama_synchronize(llama_context * ctx) { void llama_synchronize(llama_context * ctx) {
ctx->synchronize(); ctx->synchronize();
} }
...@@ -2498,7 +2390,7 @@ void llama_kv_cache_seq_cp( ...@@ -2498,7 +2390,7 @@ void llama_kv_cache_seq_cp(
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1) { llama_pos p1) {
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
} }
void llama_kv_self_seq_cp( void llama_kv_self_seq_cp(
...@@ -2512,14 +2404,14 @@ void llama_kv_self_seq_cp( ...@@ -2512,14 +2404,14 @@ void llama_kv_self_seq_cp(
return; return;
} }
return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
} }
// deprecated // deprecated
void llama_kv_cache_seq_keep( void llama_kv_cache_seq_keep(
llama_context * ctx, llama_context * ctx,
llama_seq_id seq_id) { llama_seq_id seq_id) {
return llama_kv_self_seq_keep(ctx, seq_id); llama_kv_self_seq_keep(ctx, seq_id);
} }
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
...@@ -2528,7 +2420,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { ...@@ -2528,7 +2420,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
return; return;
} }
return kv->seq_keep(seq_id); kv->seq_keep(seq_id);
} }
// deprecated // deprecated
...@@ -2538,7 +2430,7 @@ void llama_kv_cache_seq_add( ...@@ -2538,7 +2430,7 @@ void llama_kv_cache_seq_add(
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta) { llama_pos delta) {
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
} }
void llama_kv_self_seq_add( void llama_kv_self_seq_add(
...@@ -2552,7 +2444,7 @@ void llama_kv_self_seq_add( ...@@ -2552,7 +2444,7 @@ void llama_kv_self_seq_add(
return; return;
} }
return kv->seq_add(seq_id, p0, p1, delta); kv->seq_add(seq_id, p0, p1, delta);
} }
// deprecated // deprecated
...@@ -2562,7 +2454,7 @@ void llama_kv_cache_seq_div( ...@@ -2562,7 +2454,7 @@ void llama_kv_cache_seq_div(
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d) { int d) {
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
} }
void llama_kv_self_seq_div( void llama_kv_self_seq_div(
...@@ -2576,7 +2468,7 @@ void llama_kv_self_seq_div( ...@@ -2576,7 +2468,7 @@ void llama_kv_self_seq_div(
return; return;
} }
return kv->seq_div(seq_id, p0, p1, d); kv->seq_div(seq_id, p0, p1, d);
} }
// deprecated // deprecated
...@@ -2595,7 +2487,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { ...@@ -2595,7 +2487,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
// deprecated // deprecated
void llama_kv_cache_defrag(llama_context * ctx) { void llama_kv_cache_defrag(llama_context * ctx) {
return llama_kv_self_defrag(ctx); llama_kv_self_defrag(ctx);
} }
void llama_kv_self_defrag(llama_context * ctx) { void llama_kv_self_defrag(llama_context * ctx) {
...@@ -2604,7 +2496,8 @@ void llama_kv_self_defrag(llama_context * ctx) { ...@@ -2604,7 +2496,8 @@ void llama_kv_self_defrag(llama_context * ctx) {
return; return;
} }
return kv->defrag(); // force defrag
kv->defrag_sched(-1.0f);
} }
// deprecated // deprecated
...@@ -2788,3 +2681,34 @@ void llama_perf_context_print(const llama_context * ctx) { ...@@ -2788,3 +2681,34 @@ void llama_perf_context_print(const llama_context * ctx) {
void llama_perf_context_reset(llama_context * ctx) { void llama_perf_context_reset(llama_context * ctx) {
ctx->perf_reset(); ctx->perf_reset();
} }
//
// training
//
bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
GGML_UNUSED(tensor);
GGML_UNUSED(userdata);
return true;
}
void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
ctx->opt_init(model, lopt_params);
}
void llama_opt_epoch(
struct llama_context * ctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval) {
ctx->opt_epoch(
dataset,
result_train,
result_eval,
idata_split,
callback_train,
callback_eval);
}
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "llama-kv-cache.h" #include "llama-kv-cache.h"
#include "ggml-cpp.h" #include "ggml-cpp.h"
#include "ggml-opt.h"
#include <map> #include <map>
#include <vector> #include <vector>
...@@ -28,7 +29,12 @@ struct llama_context { ...@@ -28,7 +29,12 @@ struct llama_context {
void synchronize(); void synchronize();
const llama_model & get_model() const; const llama_model & get_model() const;
const llama_cparams & get_cparams() const;
ggml_backend_sched_t get_sched() const;
ggml_context * get_ctx_compute() const;
uint32_t n_ctx() const; uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const; uint32_t n_ctx_per_seq() const;
...@@ -66,7 +72,6 @@ struct llama_context { ...@@ -66,7 +72,6 @@ struct llama_context {
void set_embeddings (bool value); void set_embeddings (bool value);
void set_causal_attn(bool value); void set_causal_attn(bool value);
void set_warmup(bool value); void set_warmup(bool value);
void set_cross_attn(bool value);
void set_adapter_lora( void set_adapter_lora(
llama_adapter_lora * adapter, llama_adapter_lora * adapter,
...@@ -130,6 +135,32 @@ struct llama_context { ...@@ -130,6 +135,32 @@ struct llama_context {
llama_perf_context_data perf_get_data() const; llama_perf_context_data perf_get_data() const;
void perf_reset(); void perf_reset();
//
// training
//
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
void opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
void opt_epoch_iter(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,
int64_t ndata_in_loop,
int64_t t_loop_start);
private: private:
// //
// output // output
...@@ -139,50 +170,30 @@ private: ...@@ -139,50 +170,30 @@ private:
// Returns max number of outputs for which space was reserved. // Returns max number of outputs for which space was reserved.
int32_t output_reserve(int32_t n_outputs); int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
void output_reorder();
// //
// graph // graph
// //
public:
int32_t graph_max_nodes() const; int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph // zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init(); ggml_cgraph * graph_init();
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
private:
llm_graph_result_ptr graph_build( llm_graph_result_ptr graph_build(
ggml_context * ctx, ggml_context * ctx,
ggml_cgraph * gf, ggml_cgraph * gf,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
llm_graph_type gtype); llm_graph_type gtype);
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
llm_graph_cb graph_get_cb() const; llm_graph_cb graph_get_cb() const;
// used by kv_self_update()
ggml_tensor * build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf,
const std::vector<struct llama_kv_defrag_move> & moves) const;
// TODO: read/write lora adapters and cvec // TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io); size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io); size_t state_read_data (llama_io_read_i & io);
...@@ -199,14 +210,10 @@ private: ...@@ -199,14 +210,10 @@ private:
llama_cparams cparams; llama_cparams cparams;
llama_adapter_cvec cvec; llama_adapter_cvec cvec;
llama_adapter_loras loras; llama_adapter_loras loras;
llama_sbatch sbatch;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_kv_cache_unified> kv_self; std::unique_ptr<llama_memory_i> memory;
// TODO: remove
bool logits_all = false;
// decode output (2-dimensional array: [n_outputs][n_vocab]) // decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits size_t logits_size = 0; // capacity (of floats) for logits
...@@ -233,6 +240,9 @@ private: ...@@ -233,6 +240,9 @@ private:
ggml_context_ptr ctx_compute; ggml_context_ptr ctx_compute;
// training
ggml_opt_context_t opt_ctx = nullptr;
ggml_threadpool_t threadpool = nullptr; ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr; ggml_threadpool_t threadpool_batch = nullptr;
......
...@@ -29,8 +29,8 @@ struct llama_cparams { ...@@ -29,8 +29,8 @@ struct llama_cparams {
bool offload_kqv; bool offload_kqv;
bool flash_attn; bool flash_attn;
bool no_perf; bool no_perf;
bool cross_attn;
bool warmup; bool warmup;
bool op_offload;
enum llama_pooling_type pooling_type; enum llama_pooling_type pooling_type;
......
...@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { ...@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) { for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self->head; data[i] = kv_self->s_copy(i);
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
kv_cell.src = cell_id;
}
data[i] = kv_cell.src;
// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
}
} }
} }
} }
...@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { ...@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
// clear unused states // clear unused states
for (int i = 0; i < n_kv; ++i) { for (int i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self->head; data[i] = kv_self->s_mask(i);
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
data[i] = (float) (kv_cell.src >= 0);
// only clear once
if (kv_cell.src < 0) {
kv_cell.src = cell_id;
}
} }
} }
} }
...@@ -560,12 +532,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { ...@@ -560,12 +532,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
} }
} }
void llm_graph_input_cross_attn_state::set_input(const llama_ubatch * ubatch) {
if (ubatch->embd) {
ggml_backend_tensor_set(cross_attn_state, ubatch->embd, 0, ggml_nbytes(cross_attn_state));
}
}
// //
// llm_graph_context // llm_graph_context
// //
...@@ -816,7 +782,7 @@ ggml_tensor * llm_graph_context::build_ffn( ...@@ -816,7 +782,7 @@ ggml_tensor * llm_graph_context::build_ffn(
} break; } break;
} }
if (type_gate == LLM_FFN_PAR) { if (gate && type_gate == LLM_FFN_PAR) {
cur = ggml_mul(ctx0, cur, tmp); cur = ggml_mul(ctx0, cur, tmp);
cb(cur, "ffn_gate_par", il); cb(cur, "ffn_gate_par", il);
} }
...@@ -1005,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { ...@@ -1005,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
//cb(inp->tokens, "inp_tokens", -1); //cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens); ggml_set_input(inp->tokens);
res->t_tokens = inp->tokens;
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
...@@ -1111,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { ...@@ -1111,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
} }
ggml_tensor * llm_graph_context::build_inp_s_copy() const { ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self); auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
...@@ -1128,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { ...@@ -1128,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
} }
ggml_tensor * llm_graph_context::build_inp_s_mask() const { ggml_tensor * llm_graph_context::build_inp_s_mask() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self); auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
...@@ -1261,8 +1228,19 @@ ggml_tensor * llm_graph_context::build_attn_mha( ...@@ -1261,8 +1228,19 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
if (v_mla) { if (v_mla) {
#if 0
// v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
// However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
cur = ggml_mul_mat(ctx0, v_mla, cur); cur = ggml_mul_mat(ctx0, v_mla, cur);
#else
// It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
// The permutations are noops and only change how the tensor data is interpreted.
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_mul_mat(ctx0, v_mla, cur);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
#endif
} }
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
...@@ -1442,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn( ...@@ -1442,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache // store to KV cache
{ {
GGML_ASSERT(!kv_self->recurrent);
const auto kv_head = kv_self->head; const auto kv_head = kv_self->head;
GGML_ASSERT(kv_self->size == n_ctx); GGML_ASSERT(kv_self->size == n_ctx);
...@@ -1538,25 +1514,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { ...@@ -1538,25 +1514,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
} }
ggml_tensor * llm_graph_context::build_inp_cross_attn_state() const {
const int64_t n_embd = hparams.n_embd;
auto inp = std::make_unique<llm_graph_input_cross_attn_state>();
ggml_tensor * cur = nullptr;
inp->cross_attn_state = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, 1601, 4);
ggml_set_input(inp->cross_attn_state);
cur = inp->cross_attn_state;
cb(cur, "inp_cross_attn_state", -1);
res->add_input(std::move(inp));
return cur;
}
ggml_tensor * llm_graph_context::build_attn( ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_cross * inp, llm_graph_input_attn_cross * inp,
ggml_cgraph * gf, ggml_cgraph * gf,
...@@ -1612,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state( ...@@ -1612,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
ggml_tensor * state_mask, ggml_tensor * state_mask,
int32_t n_state, int32_t n_state,
int32_t n_seqs) const { int32_t n_seqs) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_kv = kv_self->n; const auto n_kv = kv_self->n;
const auto kv_head = kv_self->head; const auto kv_head = kv_self->head;
...@@ -1644,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ...@@ -1644,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * state_mask, ggml_tensor * state_mask,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
int il) const { int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto token_shift_count = hparams.token_shift_count; const auto token_shift_count = hparams.token_shift_count;
...@@ -1665,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ...@@ -1665,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift, ggml_tensor * token_shift,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
int il) const { int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto token_shift_count = hparams.token_shift_count; const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd; const auto n_embd = hparams.n_embd;
......
...@@ -19,6 +19,7 @@ struct llama_cparams; ...@@ -19,6 +19,7 @@ struct llama_cparams;
class llama_memory_i; class llama_memory_i;
class llama_kv_cache_unified; class llama_kv_cache_unified;
class llama_kv_cache_recurrent;
// certain models (typically multi-modal) can produce different types of graphs // certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type { enum llm_graph_type {
...@@ -86,7 +87,6 @@ public: ...@@ -86,7 +87,6 @@ public:
ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061]
}; };
class llm_graph_input_pos : public llm_graph_input_i { class llm_graph_input_pos : public llm_graph_input_i {
...@@ -187,26 +187,26 @@ public: ...@@ -187,26 +187,26 @@ public:
class llm_graph_input_s_copy : public llm_graph_input_i { class llm_graph_input_s_copy : public llm_graph_input_i {
public: public:
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default; virtual ~llm_graph_input_s_copy() = default;
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size] ggml_tensor * s_copy; // I32 [kv_size]
const llama_kv_cache_unified * kv_self; const llama_kv_cache_recurrent * kv_self;
}; };
class llm_graph_input_s_mask : public llm_graph_input_i { class llm_graph_input_s_mask : public llm_graph_input_i {
public: public:
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default; virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv] ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_unified * kv_self; const llama_kv_cache_recurrent * kv_self;
}; };
class llm_graph_input_cross_embd : public llm_graph_input_i { class llm_graph_input_cross_embd : public llm_graph_input_i {
...@@ -284,16 +284,6 @@ public: ...@@ -284,16 +284,6 @@ public:
const llama_cross * cross = nullptr; const llama_cross * cross = nullptr;
}; };
class llm_graph_input_cross_attn_state : public llm_graph_input_i {
public:
llm_graph_input_cross_attn_state() = default;
virtual ~llm_graph_input_cross_attn_state() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061]
};
// //
// llm_graph_result // llm_graph_result
// //
...@@ -308,6 +298,7 @@ class llm_graph_result_i { ...@@ -308,6 +298,7 @@ class llm_graph_result_i {
public: public:
virtual ~llm_graph_result_i() = default; virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_tokens() = 0;
virtual ggml_tensor * get_logits() = 0; virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0; virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0; virtual ggml_tensor * get_embd_pooled() = 0;
...@@ -322,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i { ...@@ -322,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i {
public: public:
virtual ~llm_graph_result() = default; virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() override { return t_tokens; }
ggml_tensor * get_logits() override { return t_logits; } ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; } ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
...@@ -338,6 +330,7 @@ public: ...@@ -338,6 +330,7 @@ public:
} }
// important graph nodes // important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr; ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr; ggml_tensor * t_embd_pooled = nullptr;
...@@ -361,8 +354,8 @@ struct llm_graph_params { ...@@ -361,8 +354,8 @@ struct llm_graph_params {
const llama_cparams & cparams; const llama_cparams & cparams;
const llama_ubatch & ubatch; const llama_ubatch & ubatch;
ggml_backend_sched * sched; ggml_backend_sched_t sched;
ggml_backend * backend_cpu; ggml_backend_t backend_cpu;
const llama_adapter_cvec * cvec; const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras; const llama_adapter_loras * loras;
...@@ -413,9 +406,9 @@ struct llm_graph_context { ...@@ -413,9 +406,9 @@ struct llm_graph_context {
ggml_context * ctx0 = nullptr; ggml_context * ctx0 = nullptr;
ggml_backend_sched * sched; ggml_backend_sched_t sched;
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec; const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras; const llama_adapter_loras * loras;
...@@ -502,7 +495,6 @@ struct llm_graph_context { ...@@ -502,7 +495,6 @@ struct llm_graph_context {
ggml_tensor * build_inp_cls() const; ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const; ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_mask() const; ggml_tensor * build_inp_s_mask() const;
ggml_tensor * build_inp_cross_attn_state() const;
ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const; ggml_tensor * build_inp_pos_bucket_enc() const;
......
...@@ -85,7 +85,3 @@ bool llama_hparams::is_swa(uint32_t il) const { ...@@ -85,7 +85,3 @@ bool llama_hparams::is_swa(uint32_t il) const {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
bool llama_hparams::cross_attention_layers(uint32_t il) const {
return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment