Unverified commit 1a19df1f authored by Michael Yang, committed by GitHub

update vendored llama.cpp and ggml (#11823)

* TEMPORARY: Update the llama.cpp upstream to my fork's Granite Four branch

This will be redone once my branch is merged upstream in llama.cpp

* feat: Update all patches

There are a number that are no longer needed at all:

- 0003-embeddings: Embeddings entirely overhauled on master
- 0008-ensure-KV-cache-is-fully-defragmented: KV caching entirely
    overhauled on master
- 0019-metal-add-mean-kernel-14267: Merged upstream
- 0020-CUDA-add-mean-operation-14313: Merged upstream

* feat: Sync llama.cpp and ggml

* fix: Update rsync-filter for all moved/new/removed files

* fix: Add files missing from sync

* fix: Update ggml rsync-filter for new ggml-cpu/arch subdirs

* fix: Add ggml files missing from sync

* fix: Narrow llama.cpp rsync-filter to not include mtmd main tool cpp files

* fix: Remove mtmd main cpp files

* fix: Add missing include in sampling_ext.cpp

* fix: Update llama.go to use mtmd instead of clip/llava

* fix: Add patch for mtmd_input_text

* chore: Ignore *.patched in the patch directory

* fix: Fix support for arch-specific ggml-cpu source files with new arrangement

In https://github.com/ggml-org/llama.cpp/pull/13892, all arch-specific
implementations were split out into a nested tree structure under
ggml-cpu/arch. This conflicts with the standard cgo layout, where all
arch-specific source files are expected to live in the same directory as
the parent Go package and use filename suffixes based on GOOS and GOARCH.
As such, there were two viable options for getting this to work:

1. Add a patch on top of the GGML sync to rearrange the files to match the
Go layout convention
2. Use cgo directives to conditionally include the nested source files in
the compilation units

This commit does (2) in order to minimize the set of changes needed on top
of the upstream file layout. To get this to work, there are two key things
needed:

1. In cpu.go, #cgo directives are added to explicitly define __${GOARCH}__ for
the preprocessor
2. In arch-impls.c|cpp, an #ifdef | #elif defined | #endif chain explicitly
includes the .c|.cpp files for the given architecture from the nested
directory (a rough sketch of this pattern follows below)
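
As a rough illustration of this pattern (the shim file name, macro spellings,
and include paths below are assumptions for the sketch, not the literal
contents of cpu.go or the vendored tree):

    /* arch-impls.c -- illustrative sketch only.
     * cpu.go is assumed to pass a per-architecture define through cgo build
     * constraints, e.g. "#cgo amd64 CPPFLAGS: -D__amd64__" and
     * "#cgo arm64 CPPFLAGS: -D__arm64__", so exactly one branch is active
     * for the GOARCH being compiled. */
    #if defined(__amd64__)
    #    include "arch/x86/quants.c"
    #elif defined(__arm64__)
    #    include "arch/arm/quants.c"
    #endif
    /* other GOARCH values get their own #elif branches */

This keeps the upstream directory structure untouched, so later syncs only
need to touch the include list instead of re-applying a file-moving patch.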

* fix: Use mtmd_helper to correctly load the bitmap for the image

* fix: Apply patch for mtmd_text_input

* fix: Add missing stb to llama.cpp rsync-filter

* fix: Add sync'ed stb vendored header

* fix: Use C++17 and include vendor for Go wrapper modules

* fix: Update patch 0015 for upstream implementation of uuid

* feat: Bump to the latest tip of the branch

* fix: Update patches for bump

* feat: Bump back to the central repo and point at the latest master

This includes Granite 4 and a number of other model architectures!

* fix: Revert changes to ggml export GPU UUID patch

* fix: Add patch for GGML_VERSION and GGML_COMMIT constants

* feat: Sync all patched code

* build: Include cmake/common.cmake in ggml sync

* build: Add top-level include for GNUInstallDirs in CMakeLists.txt

This is used to populate CMAKE_INSTALL_BINDIR

* fix: Add a patch to avoid power throttling API on non-msvc windows builds

* fix: Sync patch changes for ggml-cpu.c

* feat: Bump llama.cpp to 4a4f42

This picks up support for Kimi K2 and PLaMO-2

* feat: Sync llama.cpp

* fix: Handle multi-chunk image encodings from mtmd

* fix: Re-number patches after merge with `main`

* feat: Bump to 41e78c in the makefile

* fix: Fix Solar and argsort/copy patches after bump

* fix: Remove Gemma3n CUDA Graphs patch

It was implemented upstream:
https://github.com/ggml-org/llama.cpp/pull/14741

* feat: Sync llama.cpp / ggml after latest bump

* build: Remove unnecessary CFLAGS definitions in cpu.go

* fix: Remove unnecessary additions in the rsync-filter

* fix: Remove unused vendored code for chat template parsing

* Revert "fix: Remove Gemma3n CUDA Graphs patch"

This reverts commit d724caced3ce21f08924d4b7801f94ce6638f6ea.

* fix: Update 0020 CUDA Graphs for gemma3n to keep both llama.cpp and ollama fixes

https://github.com/ollama/ollama/pull/11195#issuecomment-3137312394



* fix: Sync ggml-cuda.cu after keeping both style cuda graph fixes for gemma3n

* unwind mxfp4 patch

Prepare to bump ggml with their impl for mxfp4

* bump

* fix windows build error

* Convert tensors at load time

Repack the mxfp4 tensors as ggml's kernels expect them to be.

* convert mlp bf16 to f32

* buffer the conversion better

* reshape earlier

* openai swiglu

* add ids

* split qkv, gate_up

* fix nested alt tags

* fast attention

* remove debug messages

* fix lint

* remove redundant test

* remap values only if source/target are different

* add back i32->i32 copy

* refactor cpu quants

* clean up vendor

* update patch instructions

* clean up patches

* remove webgpu

* update mem

* also handle gpt-oss

* revert convert changes

---------
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
parent 7ccfd97a
......@@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
include(CheckLanguage)
include(GNUInstallDirs)
find_package(Threads REQUIRED)
......@@ -51,7 +52,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
add_compile_definitions(NDEBUG)
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
......
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618
FETCH_HEAD=e54d41befcc1575f4c898c5ff4ef43970cead75f
.PHONY: help
help:
......@@ -12,7 +12,7 @@ help:
@echo " clean Clean local repository"
@echo
@echo "Example:"
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean apply-patches sync"
.PHONY: sync
sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
......@@ -24,12 +24,12 @@ ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
go generate ./$(@D)
.PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
llama/llama.cpp: llama/vendor
rsync -arvzc --delete -f "include LICENSE" -f "merge $@/.rsync-filter" $(addprefix $<,/LICENSE /) $@
.PHONY: ml/backend/ggml/ggml
ml/backend/ggml/ggml: llama/vendor/ggml/
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
ml/backend/ggml/ggml: llama/vendor
rsync -arvzc --delete -f "include LICENSE" -f "merge $@/.rsync-filter" $(addprefix $<,/LICENSE /ggml/) $@
PATCHES=$(wildcard llama/patches/*.patch)
PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))
......@@ -39,7 +39,15 @@ PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATC
apply-patches: $(PATCHED)
llama/patches/.%.patched: llama/patches/%.patch
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then \
touch $@; \
else \
echo "Patch failed. Resolve any conflicts then continue."; \
echo "1. Run 'git -C $(WORKDIR) am --continue'"; \
echo "2. Run 'make -f $(lastword $(MAKEFILE_LIST)) format-patches'"; \
echo "3. Run 'make -f $(lastword $(MAKEFILE_LIST)) clean apply-patches'"; \
exit 1; \
fi
.PHONY: checkout
checkout: $(WORKDIR)
......@@ -60,4 +68,5 @@ format-patches: llama/patches
.PHONY: clean
clean: checkout
@git -C $(WORKDIR) am --abort || true
$(RM) llama/patches/.*.patched
......@@ -39,6 +39,7 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
strings.HasSuffix(t.name, ".bias") ||
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||
......
......@@ -180,7 +180,7 @@ func (kv KV) OllamaEngineRequired() bool {
"llama4",
"mllama",
"qwen25vl",
"gptoss",
"gptoss", "gpt-oss",
}, kv.Architecture())
}
......@@ -328,7 +328,7 @@ func (t TensorType) TypeSize() uint64 {
return 2 + blockSize/2
case TensorTypeQ4_1:
return 2 + 2 + blockSize/2
case TensorTypeMXFP4:
case TensorTypeMXFP4, 39:
return 1 + blockSize/2
case TensorTypeQ5_0:
return 2 + 4 + blockSize/2
......@@ -665,7 +665,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
4*qkvBias.Shape[0],
)
}
case "gptoss":
case "gptoss", "gpt-oss":
kv = make([]uint64, f.KV().BlockCount())
for i := range kv {
kv[i] = uint64(float64((embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
......@@ -675,8 +675,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
kv[i] *= context
}
}
fullOffload = 4 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
partialOffload = fullOffload
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
}
return
......@@ -761,10 +760,6 @@ func (f GGML) SupportsFlashAttention() bool {
return false
}
if f.KV().Architecture() == "gptoss" {
return false
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()
......
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618";
char const *LLAMA_COMMIT = "e54d41befcc1575f4c898c5ff4ef43970cead75f";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";
protect **/*.go
include common/
include common/base64.*
include common/common.*
include common/json-schema-to-grammar.*
include common/json.*
include common/log.*
include common/sampling.*
include common/stb_image.*
include include/
include include/llama.*
include include/llama-*.*
include tools/
include tools/mtmd/
include tools/mtmd/clip.*
include tools/mtmd/clip-impl.*
include tools/mtmd/llava.*
include src/
include src/llama.*
include src/llama-*.*
include src/unicode-data.*
include src/unicode.*
exclude *
protect .rsync-filter
protect *.go
include /common/
include /common/base64.*
include /common/common.*
include /common/json-schema-to-grammar.*
include /common/json.*
include /common/log.*
include /common/sampling.*
include /include/
include /include/llama.*
include /include/llama-*.*
include /tools/
include /tools/mtmd/
include /tools/mtmd/*.h
include /tools/mtmd/clip.cpp
include /tools/mtmd/mtmd.cpp
include /tools/mtmd/mtmd-audio.cpp
include /tools/mtmd/mtmd-helper.cpp
include /src/
include /src/llama.*
include /src/llama-*.*
include /src/unicode-data.*
include /src/unicode.*
include /vendor/
include /vendor/miniaudio/
include /vendor/miniaudio/*.h
include /vendor/nlohmann/
include /vendor/nlohmann/*.hpp
include /vendor/stb/
include /vendor/stb/*.h
hide *
......@@ -203,6 +203,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
DWORD p = NORMAL_PRIORITY_CLASS;
switch (prio) {
case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break;
case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break;
case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break;
case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break;
......@@ -228,6 +229,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
int p = 0;
switch (prio) {
case GGML_SCHED_PRIO_LOW: p = 5; break;
case GGML_SCHED_PRIO_NORMAL: p = 0; break;
case GGML_SCHED_PRIO_MEDIUM: p = -5; break;
case GGML_SCHED_PRIO_HIGH: p = -10; break;
......@@ -443,9 +445,37 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
bool has_suffix = string_ends_with(str, suffix);
if (has_suffix) {
str = str.substr(0, str.size() - suffix.size());
}
return has_suffix;
}
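// Returns the index in `str` where a partial occurrence of `stop` begins,
// i.e. the start of the longest prefix of `stop` that is also a suffix of
// `str`, or std::string::npos if there is no such partial match.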
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const auto current_partial = stop.substr(0, char_index + 1);
if (string_ends_with(str, current_partial)) {
return str.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
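// note: "$&" refers to the entire match in ECMAScript replacement syntax,
// whereas "$0" is not a standard way to reference it, hence the change below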
return std::regex_replace(s, special_chars, "\\$0");
return std::regex_replace(s, special_chars, "\\$&");
}
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
......@@ -685,11 +715,17 @@ bool fs_validate_filename(const std::string & filename) {
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
filename_utf32 = converter.from_bytes(filename);
......@@ -746,6 +782,9 @@ bool fs_validate_filename(const std::string & filename) {
return true;
}
#include <iostream>
// returns true if successful, false otherwise
bool fs_create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
......@@ -763,9 +802,16 @@ bool fs_create_directory_with_parents(const std::string & path) {
// process path from front to back, procedurally creating directories
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
const std::wstring subpath = wpath.substr(0, pos_slash);
const wchar_t * test = subpath.c_str();
const bool success = CreateDirectoryW(test, NULL);
pos_slash += 1;
// skip the drive letter, in some systems it can return an access denied error
if (subpath.length() == 2 && subpath[1] == ':') {
continue;
}
const bool success = CreateDirectoryW(subpath.c_str(), NULL);
if (!success) {
const DWORD error = GetLastError();
......@@ -779,8 +825,6 @@ bool fs_create_directory_with_parents(const std::string & path) {
return false;
}
}
pos_slash += 1;
}
return true;
......@@ -830,7 +874,7 @@ std::string fs_get_cache_directory() {
if (getenv("LLAMA_CACHE")) {
cache_directory = std::getenv("LLAMA_CACHE");
} else {
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
if (std::getenv("XDG_CACHE_HOME")) {
cache_directory = std::getenv("XDG_CACHE_HOME");
} else {
......@@ -876,31 +920,6 @@ struct common_init_result common_init_from_params(common_params & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.reranking) {
bool ok = true;
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
ok = false;
}
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
ok = false;
}
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
ok = false;
}
if (!ok) {
llama_model_free(model);
return iparams;
}
}
auto cparams = common_context_params_to_llama(params);
llama_context * lctx = llama_init_from_model(model, cparams);
......@@ -910,7 +929,7 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
}
......@@ -942,6 +961,35 @@ struct common_init_result common_init_from_params(common_params & params) {
}
}
if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
bool ok = true;
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
ok = false;
}
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
if (!has_eos && !has_sep) {
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
ok = false;
} else if (!has_eos) {
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
} else if (!has_sep) {
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
ok = false;
}
if (!ok) {
llama_free(lctx);
llama_model_free(model);
return iparams;
}
}
// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
llama_adapter_lora_ptr lora;
......@@ -966,15 +1014,21 @@ struct common_init_result common_init_from_params(common_params & params) {
params.sampling.ignore_eos = false;
}
if (params.sampling.ignore_eos) {
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias.push_back({i, -INFINITY});
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
if (params.sampling.penalty_last_n == -1) {
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.penalty_last_n = llama_n_ctx(lctx);
......@@ -1017,7 +1071,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
}
llama_kv_self_clear(lctx);
llama_memory_clear(llama_get_memory(lctx), true);
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
llama_set_warmup(lctx, false);
......@@ -1068,6 +1122,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
mparams.use_extra_bufts = !params.no_extra_bufts;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
......@@ -1083,6 +1138,9 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
}
mparams.progress_callback = params.load_progress_callback;
mparams.progress_callback_user_data = params.load_progress_callback_user_data;
return mparams;
}
......@@ -1114,11 +1172,8 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
if (params.reranking) {
cparams.embeddings = true;
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
}
cparams.swa_full = params.swa_full;
cparams.kv_unified = params.kv_unified;
cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;
......@@ -1252,6 +1307,9 @@ std::vector<llama_token> common_tokenize(
int n_tokens = text.length() + 2 * add_special;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
if (n_tokens == std::numeric_limits<int32_t>::min()) {
throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
}
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
......@@ -1306,81 +1364,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
return text;
}
//
// KV cache utils
//
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
int seq_count = 0;
for (int j = 0; j < view.n_seq_max; j++) {
if (cs_curr[j] >= 0) { seq_count++; }
}
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
}
printf("\n=== Done dumping\n");
}
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
std::unordered_map<llama_seq_id, size_t> seqs;
llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
for (int j = 0; j < view.n_seq_max; j++) {
if (cs_curr[j] < 0) { continue; }
if (seqs.find(cs_curr[j]) == seqs.end()) {
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
const size_t sz = seqs.size();
seqs[cs_curr[j]] = sz;
}
}
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
}
printf("=== Sequence legend: ");
for (const auto & it : seqs) {
printf("%zu=%d, ", it.second, it.first);
}
printf("'+'=other sequence ids");
c_curr = view.cells;
cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
for (int j = 0; j < view.n_seq_max; j++) {
if (cs_curr[j] >= 0) {
const auto & it = seqs.find(cs_curr[j]);
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
} else {
putchar('.');
}
}
putchar(' ');
}
printf("\n=== Done dumping\n");
}
//
// Embedding utils
//
......
package common
// #cgo CXXFLAGS: -std=c++11
// #cgo CPPFLAGS: -I${SRCDIR}/../include
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/../vendor
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
import "C"
......@@ -6,7 +6,9 @@
#include <set>
#include <string>
#include <string_view>
#include <vector>
#include <map>
#include <sstream>
#ifdef _WIN32
......@@ -75,10 +77,11 @@ enum llama_example {
LLAMA_EXAMPLE_SERVER,
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
LLAMA_EXAMPLE_EXPORT_LORA,
LLAMA_EXAMPLE_LLAVA,
LLAMA_EXAMPLE_MTMD,
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_DIFFUSION,
LLAMA_EXAMPLE_COUNT,
};
......@@ -114,7 +117,7 @@ enum common_grammar_trigger_type {
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
};
struct common_grammar_trigger {
......@@ -175,7 +178,8 @@ struct common_params_sampling {
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// print the parameters into a string
std::string print() const;
......@@ -197,6 +201,10 @@ struct common_params_speculative {
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
......@@ -212,9 +220,26 @@ struct common_params_vocoder {
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};
struct common_params_diffusion {
int32_t steps = 128;
bool visual_mode = false;
float eps = 0; // epsilon for timesteps
int32_t block_length = 0; // block length for generation
int32_t algorithm = 4; // default algorithm: low-confidence
float alg_temp = 0.0f; // algorithm temperature
float cfg_scale = 0; // classifier-free guidance scale
bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
};
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
COMMON_REASONING_FORMAT_AUTO,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};
struct common_params {
......@@ -262,6 +287,7 @@ struct common_params {
struct common_params_sampling sampling;
struct common_params_speculative speculative;
struct common_params_vocoder vocoder;
struct common_params_diffusion diffusion;
struct common_params_model model;
......@@ -290,6 +316,7 @@ struct common_params {
int32_t verbosity = 0;
int32_t control_vector_layer_start = -1; // layer range for control vector
int32_t control_vector_layer_end = -1; // layer range for control vector
bool offline = false;
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
......@@ -322,17 +349,19 @@ struct common_params {
bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool kv_unified = false; // enable unified KV cache
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation
bool display_prompt = true; // print prompt before generation
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
bool single_turn = false; // single turn chat conversation
......@@ -352,7 +381,7 @@ struct common_params {
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
bool reranking = false; // enable reranking support on server
std::string cls_sep = "\t"; // separator of classification sequences
// server params
int32_t port = 8080; // server listens on this network port
......@@ -363,16 +392,21 @@ struct common_params {
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::string api_prefix = ""; // NOLINT
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
int reasoning_budget = -1;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
std::vector<std::string> api_keys;
std::string ssl_file_key = ""; // NOLINT
std::string ssl_file_cert = ""; // NOLINT
std::map<std::string, std::string> default_template_kwargs;
// "advanced" endpoints are disabled by default for better security
bool webui = true;
bool endpoint_slots = false;
......@@ -407,10 +441,12 @@ struct common_params {
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
int32_t i_chunk = 0; // start processing from this chunk
int8_t imat_dat = 0; // whether the legacy imatrix.dat format should be output (gguf <= 0 < dat)
bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity
bool show_statistics = false; // show imatrix statistics per tensor
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
// cvector-generator params
int n_pca_batch = 100;
......@@ -426,6 +462,11 @@ struct common_params {
// common params
std::string out_file; // output filename for all example programs
// optional callback for model loading progress and cancellation:
// called with a progress value between 0.0 and 1.0.
// return false from callback to abort model loading or true to continue
llama_progress_callback load_progress_callback = NULL;
void * load_progress_callback_user_data = NULL;
};
// call once at the start of a program if it uses libcommon
......@@ -503,10 +544,10 @@ static bool string_starts_with(const std::string & str,
return str.rfind(prefix, 0) == 0;
}
static bool string_ends_with(const std::string & str,
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
// While we wait for C++20's std::string::ends_with...
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
......@@ -615,16 +656,6 @@ std::string common_detokenize(
const std::vector<llama_token> & tokens,
bool special = true);
//
// KV cache utils
//
// Dump the KV cache view with the number of sequences per cell.
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
// Dump the KV cache view showing individual sequences in each cell (long output).
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
//
// Embedding utils
//
......
#include "json-schema-to-grammar.h"
#include "common.h"
#include <nlohmann/json.hpp>
#include <algorithm>
#include <fstream>
#include <map>
#include <regex>
#include <sstream>
......@@ -40,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
class string_view {
const std::string & _str;
const size_t _start;
const size_t _end;
public:
string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
size_t size() const {
return _end - _start;
}
size_t length() const {
return size();
}
operator std::string() const {
return str();
}
std::string str() const {
return _str.substr(_start, _end - _start);
}
string_view substr(size_t pos, size_t len = std::string::npos) const {
return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
}
char operator[](size_t pos) const {
auto index = _start + pos;
if (index >= _end) {
throw std::out_of_range("string_view index out of range");
}
return _str[_start + pos];
}
bool operator==(const string_view & other) const {
std::string this_str = *this;
std::string other_str = other;
return this_str == other_str;
}
};
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int>::min();
auto has_max = max_value != std::numeric_limits<int>::max();
......@@ -111,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
}
out << "}";
};
std::function<void(const string_view &, const string_view &)> uniform_range =
[&](const string_view & from, const string_view & to) {
std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
[&](const std::string_view & from, const std::string_view & to) {
size_t i = 0;
while (i < from.length() && i < to.length() && from[i] == to[i]) {
i++;
}
if (i > 0) {
out << "\"" << from.substr(0, i).str() << "\"";
out << "\"" << from.substr(0, i) << "\"";
}
if (i < from.length() && i < to.length()) {
if (i > 0) {
......
#pragma once
#include "ggml.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include <nlohmann/json_fwd.hpp>
#include <functional>
#include <string>
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
bool force_gbnf = false);
......
......@@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<std::string> patterns_at_start;
std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
......@@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
{
const auto & pattern = trigger.value;
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
patterns_anywhere.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
......@@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
}
std::vector<std::string> trigger_patterns;
if (!patterns_at_start.empty()) {
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
......
......@@ -61,59 +61,23 @@ extern "C" {
struct llama_model;
struct llama_context;
struct llama_sampler;
struct llama_kv_cache;
typedef struct llama_memory_i * llama_memory_t;
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
typedef int32_t llama_pos;
typedef int32_t llama_token;
typedef int32_t llama_seq_id;
enum llama_vocab_type {
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
};
// pre-tokenization types
enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
};
enum llama_rope_type {
......@@ -188,6 +152,7 @@ extern "C" {
//LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
......@@ -240,18 +205,21 @@ extern "C" {
typedef bool (*llama_progress_callback)(float progress, void * user_data);
// Input data for llama_decode
// Input data for llama_encode/llama_decode
// A llama_batch object can contain input about one or many sequences
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
//
// - token : the token ids of the input (used when embd is NULL)
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// (if set to NULL, the token position will be tracked automatically by llama_decode)
// (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// (if set to NULL, only the logits for last token will be returned)
// (if set to NULL:
// - if embeddings: all tokens are output
// - if not: only the last token is output
// )
//
typedef struct llama_batch {
int32_t n_tokens;
......@@ -261,7 +229,7 @@ extern "C" {
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
int8_t * logits; // TODO: rename this to "output"
} llama_batch;
enum llama_model_kv_override_type {
......@@ -317,10 +285,11 @@ extern "C" {
const struct llama_model_kv_override * kv_overrides;
// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool check_tensors; // validate model tensor data
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool check_tensors; // validate model tensor data
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
};
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
......@@ -345,7 +314,7 @@ extern "C" {
float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
......@@ -361,10 +330,16 @@ extern "C" {
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool op_offload; // whether to offload host tensor operations to device
bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // use flash attention [EXPERIMENTAL]
bool no_perf; // measure performance timings
bool op_offload; // offload host tensor operations to device
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
};
// model quantization parameters
......@@ -381,6 +356,7 @@ extern "C" {
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
void * tensor_types; // pointer to vector containing tensor types
void * prune_layers; // pointer to vector containing layer indices to prune
} llama_model_quantize_params;
typedef struct llama_logit_bias {
......@@ -470,6 +446,7 @@ extern "C" {
LLAMA_API int64_t llama_time_us(void);
LLAMA_API size_t llama_max_devices(void);
LLAMA_API size_t llama_max_parallel_sequences(void);
LLAMA_API bool llama_supports_mmap (void);
LLAMA_API bool llama_supports_mlock (void);
......@@ -489,9 +466,11 @@ extern "C" {
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
......@@ -500,10 +479,18 @@ extern "C" {
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
// Returns the number of classifier outputs (only valid for classifier models)
// Undefined behavior for non-classifier models
LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
// Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
......@@ -552,6 +539,9 @@ extern "C" {
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,
......@@ -604,210 +594,190 @@ extern "C" {
int32_t il_end);
//
// KV cache
// Memory
//
// TODO: start using struct llama_kv_cache
// Information associated with an individual cell in the KV cache view.
struct llama_kv_cache_view_cell {
// The position for this cell. Takes KV cache shifts into account.
// May be negative if the cell is not populated.
llama_pos pos;
};
// An updateable view of the KV cache.
struct llama_kv_cache_view {
// Number of KV cache cells. This will be the same as the context size.
int32_t n_cells;
// Clear the memory contents
// If data == true, the data buffers will also be cleared together with the metadata
LLAMA_API void llama_memory_clear(
llama_memory_t mem,
bool data);
// Maximum number of sequences that can exist in a cell. It's not an error
// if there are more sequences in a cell than this value, however they will
// not be visible in the view cells_sequences.
int32_t n_seq_max;
// Number of tokens in the cache. For example, if there are two populated
// cells, the first with 1 sequence id in it and the second with 2 sequence
// ids then you'll have 3 tokens.
int32_t token_count;
// Number of populated cache cells.
int32_t used_cells;
// Maximum contiguous empty slots in the cache.
int32_t max_contiguous;
// Index to the start of the max_contiguous slot range. Can be negative
// when cache is full.
int32_t max_contiguous_idx;
// Information for an individual cell.
struct llama_kv_cache_view_cell * cells;
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API bool llama_memory_seq_rm(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
// The sequences for each cell. There will be n_seq_max items per cell.
llama_seq_id * cells_sequences;
};
// Copy all tokens that belong to the specified sequence to another sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_memory_seq_cp(
llama_memory_t mem,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
// Create an empty KV cache view. (use only for debugging purposes)
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_memory_seq_keep(
llama_memory_t mem,
llama_seq_id seq_id);
// Free a KV cache view. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_memory_seq_add(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
// TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
// Integer division of the positions by factor of `d > 1`
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_memory_seq_div(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
// Returns the smallest position present in the memory for the specified sequence
// This is typically non-zero only for SWA caches
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_memory_seq_pos_min(
llama_memory_t mem,
llama_seq_id seq_id);
// Returns the largest position present in the memory for the specified sequence
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_memory_seq_pos_max(
llama_memory_t mem,
llama_seq_id seq_id);
// Check if the memory supports shifting
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
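// Illustrative usage sketch (not part of the upstream header): a memory handle
// obtained from the context is passed to the functions declared above, in
// place of the deprecated llama_kv_self_* calls, e.g.:
//
//     llama_memory_t mem = llama_get_memory(ctx);
//     llama_memory_clear(mem, /*data=*/true);        // was llama_kv_self_clear(ctx)
//     llama_memory_seq_rm(mem, seq_id, p0, p1);      // was llama_kv_self_seq_rm(ctx, ...)
//     bool can_shift = llama_memory_can_shift(mem);  // was llama_kv_self_can_shift(ctx)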
///
//
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
//
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
"use llama_kv_self_n_tokens instead");
DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
"use llama_kv_self_used_cells instead");
DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_self_clear(
struct llama_context * ctx);
DEPRECATED(LLAMA_API void llama_kv_self_clear(
struct llama_context * ctx),
"Use llama_memory_clear() instead");
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API bool llama_kv_self_seq_rm(
DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
llama_pos p1),
"Use llama_memory_seq_rm() instead");
// Copy all tokens that belong to the specified sequence to another sequence
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_cp(
DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
llama_pos p1),
"Use llama_memory_seq_cp() instead");
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_self_seq_keep(
DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id);
llama_seq_id seq_id),
"Use llama_memory_seq_keep() instead");
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_add(
DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
llama_pos delta),
"Use llama_memory_seq_add() instead");
// Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_div(
DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
int d),
"Use llama_memory_seq_div() instead");
// Returns the smallest position present in the KV cache for the specified sequence
// This is typically non-zero only for SWA caches
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
// Return -1 if the sequence is empty
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
struct llama_context * ctx,
llama_seq_id seq_id),
"Use llama_memory_seq_pos_min() instead");
// Returns the largest position present in the KV cache for the specified sequence
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
// Return -1 if the sequence is empty
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id);
llama_seq_id seq_id),
"Use llama_memory_seq_pos_max() instead");
// Defragment the KV cache
// This will be applied:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
"use llama_memory_can_shift() instead");
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
- LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
- DEPRECATED(LLAMA_API void llama_kv_cache_clear(
-       struct llama_context * ctx),
-   "use llama_kv_self_clear instead");
- DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
-       struct llama_context * ctx,
-       llama_seq_id seq_id,
-       llama_pos p0,
-       llama_pos p1),
-   "use llama_kv_self_seq_rm instead");
- DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
-       struct llama_context * ctx,
-       llama_seq_id seq_id_src,
-       llama_seq_id seq_id_dst,
-       llama_pos p0,
-       llama_pos p1),
-   "use llama_kv_self_seq_cp instead");
- DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
-       struct llama_context * ctx,
-       llama_seq_id seq_id),
-   "use llama_kv_self_seq_keep instead");
- DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
-       struct llama_context * ctx,
-       llama_seq_id seq_id,
-       llama_pos p0,
-       llama_pos p1,
-       llama_pos delta),
-   "use llama_kv_self_seq_add instead");
- DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
-       struct llama_context * ctx,
-       llama_seq_id seq_id,
-       llama_pos p0,
-       llama_pos p1,
-       int d),
-   "use llama_kv_self_seq_div instead");
- DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
-       struct llama_context * ctx,
-       llama_seq_id seq_id),
-   "use llama_kv_self_seq_pos_max instead");
- DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
-   "use llama_kv_self_defrag instead");
- DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
-   "use llama_kv_self_can_shift instead");
- DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
-   "use llama_kv_self_update instead");
+ DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
+   "simply remove this call, updates are applied lazily on the next llama_decode()");
//
// State / sessions
//
// Returns the *actual* size in bytes of the state
- // (logits, embedding and kv_cache)
+ // (logits, embedding and memory)
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
......@@ -863,12 +833,12 @@ extern "C" {
size_t n_token_count),
"use llama_state_save_file instead");
- // Get the exact size needed to copy the KV cache of a single sequence
+ // Get the exact size needed to copy the state of a single sequence
LLAMA_API size_t llama_state_seq_get_size(
struct llama_context * ctx,
llama_seq_id seq_id);
- // Copy the KV cache of a single sequence into the specified buffer
+ // Copy the state of a single sequence into the specified buffer
LLAMA_API size_t llama_state_seq_get_data(
struct llama_context * ctx,
uint8_t * dst,
......@@ -934,18 +904,23 @@ extern "C" {
// For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
// 0 - success
- // < 0 - error. the KV cache state is restored to the state before this call
+ // < 0 - error. the memory state is restored to the state before this call
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);
// Process a batch of tokens.
- // Requires KV cache.
+ // Requires the context to have a memory.
// For encode-decoder contexts, processes the batch using the decoder.
- // Positive return values does not mean a fatal error, but rather a warning.
- //   0 - success
- //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
- // < 0 - error. the KV cache state is restored to the state before this call
+ // Upon fatal-error or abort, the ubatches that managed to be processed will remain in the memory state of the context
+ // To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
+ // Upon other return values, the memory state is restored to the state before this call
+ //    0 - success
+ //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ //    2 - aborted (processed ubatches will remain in the context's memory)
+ //   -1 - invalid input batch
+ // < -1 - fatal error (processed ubatches will remain in the context's memory)
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
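[Editor's illustration, not part of the diff] A minimal sketch of how a caller might react to the expanded llama_decode() return codes documented above, assuming the llama_get_memory() and llama_memory_seq_pos_min()/llama_memory_seq_pos_max() declarations from this header revision:

#include "llama.h"
#include <stdio.h>

// decode a batch and report what is left in the context's memory on abort/fatal error
static int32_t decode_checked(struct llama_context * ctx, struct llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // no KV slot: the caller may retry with a smaller batch or a larger context
        return ret;
    }
    if (ret == 2 || ret < -1) {
        // aborted or fatal error: some ubatches may already be in the memory,
        // so inspect what actually landed before deciding how to recover
        llama_memory_t mem = llama_get_memory(ctx);
        fprintf(stderr, "decode interrupted, seq 0 now spans [%d, %d]\n",
                (int) llama_memory_seq_pos_min(mem, 0),
                (int) llama_memory_seq_pos_max(mem, 0));
    }
    return ret;
}
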
......@@ -961,8 +936,8 @@ extern "C" {
// Get the number of threads used for prompt and batch processing (multiple token).
LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
- // Set whether the model is in embeddings mode or not
- // If true, embeddings will be returned but logits will not
+ // Set whether the context outputs embeddings or not
+ // TODO: rename to avoid confusion with llama_get_embeddings()
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
// Set whether to use causal attention or not
......@@ -986,6 +961,7 @@ extern "C" {
// in the order they have appeared in the batch.
// Rows: number of tokens for which llama_batch.logits[i] != 0
// Cols: n_vocab
// TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
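[Editor's illustration, not part of the diff] A small sketch of a greedy pick over the last output row, using llama_get_logits_ith() (described just below) together with llama_vocab_n_tokens() from this header:

#include "llama.h"

// pick the highest-logit token from the last output row of the batch
static llama_token greedy_last(struct llama_context * ctx, const struct llama_vocab * vocab) {
    const float * logits  = llama_get_logits_ith(ctx, -1); // negative index: last output
    const int32_t n_vocab = llama_vocab_n_tokens(vocab);

    llama_token best = 0;
    for (llama_token t = 1; t < n_vocab; ++t) {
        if (logits[t] > logits[best]) {
            best = t;
        }
    }
    return best;
}
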
// Logits for the ith token. For positive indices, Equivalent to:
......@@ -1000,6 +976,7 @@ extern "C" {
// in the order they have appeared in the batch.
// shape: [n_outputs*n_embd]
// Otherwise, returns NULL.
// TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith token. For positive indices, Equivalent to:
......@@ -1011,7 +988,7 @@ extern "C" {
// Get the embeddings for a sequence id
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
- // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+ // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
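[Editor's illustration, not part of the diff] A small usage sketch for the pooled per-sequence embedding; llama_model_n_embd() is assumed to be available from this header for the vector length (with LLAMA_POOLING_TYPE_RANK the buffer holds the classifier scores instead, as noted above):

#include "llama.h"
#include <stdio.h>

// print the first component of the pooled embedding for a sequence
static void print_seq_embd(struct llama_context * ctx, const struct llama_model * model, llama_seq_id seq_id) {
    const float * embd = llama_get_embeddings_seq(ctx, seq_id);
    if (embd == NULL) {
        return; // pooling_type == LLAMA_POOLING_TYPE_NONE
    }
    const int32_t n_embd = llama_model_n_embd(model);
    printf("seq %d: n_embd = %d, embd[0] = %f\n", (int) seq_id, (int) n_embd, embd[0]);
}
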
......@@ -1038,9 +1015,11 @@ extern "C" {
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
......@@ -1084,6 +1063,7 @@ extern "C" {
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
/// @return Returns the number of tokens on success, no more than n_tokens_max
/// @return Returns a negative number on failure - the number of tokens that would have been returned
/// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
/// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
/// as plaintext. Does not insert a leading space.
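[Editor's illustration, not part of the diff] A minimal sketch of the two-pass pattern implied by this contract: probe for the required size, then tokenize into an exactly sized buffer, treating INT32_MIN as a non-retryable overflow.

#include "llama.h"
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

// returns the token count on success, a negative value on failure;
// on success *out is a malloc'd buffer owned by the caller
static int32_t tokenize_all(const struct llama_vocab * vocab, const char * text, llama_token ** out) {
    const int32_t text_len = (int32_t) strlen(text);

    // first pass: n_tokens_max == 0, the negated return value is the required capacity
    int32_t n = llama_tokenize(vocab, text, text_len, NULL, 0, /*add_special=*/ true, /*parse_special=*/ true);
    if (n == INT32_MIN) {
        return INT32_MIN; // result would not fit in int32_t
    }
    n = -n;

    *out = (llama_token *) malloc(sizeof(llama_token) * (size_t) (n > 0 ? n : 1));
    return llama_tokenize(vocab, text, text_len, *out, n, true, true);
}
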
......@@ -1421,6 +1401,7 @@ extern "C" {
int32_t n_p_eval;
int32_t n_eval;
int32_t n_reused; // number of times a ggml compute graph had been reused
};
struct llama_perf_sampler_data {
......
......@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_NEO_BERT, "neo-bert" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
......@@ -33,6 +34,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
{ LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_PLAMO2, "plamo2" },
{ LLM_ARCH_CODESHELL, "codeshell" },
{ LLM_ARCH_ORION, "orion" },
{ LLM_ARCH_INTERNLM2, "internlm2" },
......@@ -41,8 +43,12 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_JAMBA, "jamba" },
{ LLM_ARCH_FALCON_H1, "falcon-h1" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" },
......@@ -56,23 +62,38 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
{ LLM_ARCH_ARWKV7, "arwkv7" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_SOLAR, "solar" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_LLADA, "llada" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
......@@ -109,6 +130,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
......@@ -166,6 +188,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
{ LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
......@@ -176,6 +199,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
{ LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
......@@ -194,13 +221,13 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
......@@ -244,6 +271,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_ARCEE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_LLAMA4,
{
......@@ -450,6 +495,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_POS_EMBD, "position_embd" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
......@@ -494,6 +540,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_NEO_BERT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
{ LLM_TENSOR_CLS, "cls" },
{ LLM_TENSOR_CLS_OUT, "cls.output" },
},
},
{
LLM_ARCH_JINA_BERT_V2,
{
......@@ -735,6 +796,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_PLAMO2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_CODESHELL,
{
......@@ -894,6 +985,42 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3N,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
{ LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" },
{ LLM_TENSOR_ALTUP_PROJ, "altup_proj" },
{ LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" },
{ LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" },
{ LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
{ LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" },
{ LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" },
{ LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" },
{ LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" },
{ LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" },
{ LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" },
{ LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" },
{ LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{
......@@ -928,6 +1055,77 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_MAMBA2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_JAMBA,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_FALCON_H1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_XVERSE,
{
......@@ -1198,6 +1396,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GLM4_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
// NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number)
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
{ LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
},
},
{
LLM_ARCH_BITNET,
{
......@@ -1321,6 +1553,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_EXAONE4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
}
},
{
LLM_ARCH_RWKV6,
{
......@@ -1483,6 +1735,46 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_GRANITE_HYBRID,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
// mamba(2) ssm layers
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
// attention layers
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
// dense FFN
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
// moe FFN
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
// shared expert
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
......@@ -1570,6 +1862,231 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_DOTS1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
}
},
{
LLM_ARCH_ERNIE4_5,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_ERNIE4_5_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_HUNYUAN_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_HUNYUAN_DENSE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_SMOLLM3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_OPENAI_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_LFM2,
{
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
}
},
{
LLM_ARCH_SMALLTHINKER,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }
},
},
{
LLM_ARCH_DREAM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_LLADA,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{
......@@ -1609,6 +2126,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
......@@ -1654,7 +2172,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
{LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
......@@ -1698,6 +2220,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
// altup / laurel (gemma 3n)
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
......@@ -1717,13 +2256,30 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
std::string LLM_KV::operator()(llm_kv kv) const {
-    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
-                  : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+
+    if (suffix != nullptr) {
+        name += ".";
+        name += suffix;
+    }
+
+    return name;
}
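[Editor's illustration, not part of the diff] The reworked operator() formats the per-architecture key first and then appends the optional suffix, so suffix placeholders inside the key format strings (the removed "tokenizer.chat_template.%s" style) are no longer needed. A hypothetical usage sketch:

//   LLM_KV kv(LLM_ARCH_MAMBA2);
//   kv(LLM_KV_SSM_GROUP_COUNT);                 // -> "mamba2.ssm.group_count"
//
//   LLM_KV kv_named(LLM_ARCH_LLAMA, "tool_use");
//   kv_named(LLM_KV_TOKENIZER_CHAT_TEMPLATE);   // -> "tokenizer.chat_template.tool_use"
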
std::string LLM_TN_IMPL::str() const {
......@@ -1762,3 +2318,40 @@ llm_arch llm_arch_from_string(const std::string & name) {
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
return LLM_TENSOR_INFOS.at(tensor);
}
bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
return true;
default:
return false;
}
}
bool llm_arch_is_hybrid(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_JAMBA:
case LLM_ARCH_FALCON_H1:
case LLM_ARCH_PLAMO2:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
return true;
default:
return false;
}
}
bool llm_arch_is_diffusion(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
return true;
default:
return false;
}
}
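[Editor's illustration, not part of the diff] These predicates let the rest of the loader branch on memory layout without enumerating architectures by hand. A hypothetical helper (not the actual call site) showing the intent:

#include "llama-arch.h"

// human-readable summary of the memory layout implied by the architecture predicates above
static const char * llm_arch_memory_kind(llm_arch arch) {
    if (llm_arch_is_hybrid(arch)) {
        return "recurrent state + KV cache"; // e.g. Jamba, Falcon-H1, PLaMO-2, Granite Hybrid, LFM2
    }
    if (llm_arch_is_recurrent(arch)) {
        return "recurrent state only";       // Mamba/RWKV family
    }
    return "KV cache only";
}
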
......@@ -24,6 +24,7 @@ enum llm_arch {
LLM_ARCH_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_NEO_BERT,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
......@@ -37,6 +38,7 @@ enum llm_arch {
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
LLM_ARCH_PLAMO,
LLM_ARCH_PLAMO2,
LLM_ARCH_CODESHELL,
LLM_ARCH_ORION,
LLM_ARCH_INTERNLM2,
......@@ -45,8 +47,12 @@ enum llm_arch {
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_JAMBA,
LLM_ARCH_FALCON_H1,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2,
......@@ -60,23 +66,38 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
LLM_ARCH_ARWKV7,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_GRANITE_HYBRID,
LLM_ARCH_CHAMELEON,
LLM_ARCH_SOLAR,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_HUNYUAN_DENSE,
LLM_ARCH_SMOLLM3,
LLM_ARCH_OPENAI_MOE,
LLM_ARCH_LFM2,
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_LLADA,
LLM_ARCH_UNKNOWN,
};
......@@ -113,6 +134,7 @@ enum llm_kv {
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_NEXTN_PREDICT_LAYERS,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
......@@ -170,6 +192,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_SSM_DT_B_C_RMS,
LLM_KV_WKV_HEAD_SIZE,
......@@ -192,13 +215,13 @@ enum llm_kv {
LLM_KV_TOKENIZER_MASK_ID,
LLM_KV_TOKENIZER_ADD_BOS,
LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_ADD_SEP,
LLM_KV_TOKENIZER_ADD_PREFIX,
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV,
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
LLM_KV_TOKENIZER_FIM_PRE_ID,
LLM_KV_TOKENIZER_FIM_SUF_ID,
LLM_KV_TOKENIZER_FIM_MID_ID,
......@@ -215,6 +238,10 @@ enum llm_kv {
LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
LLM_KV_CONVNEXT_BLOCK_COUNT,
LLM_KV_CLASSIFIER_OUTPUT_LABELS,
LLM_KV_SHORTCONV_L_CACHE,
// deprecated:
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
......@@ -241,6 +268,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_OUT_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
......@@ -265,12 +293,32 @@ enum llm_tensor {
LLM_TENSOR_LAYER_OUT_NORM,
LLM_TENSOR_POST_ATTN_NORM,
LLM_TENSOR_POST_MLP_NORM,
LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n
LLM_TENSOR_PER_LAYER_PROJ, // gemma3n
LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n
LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n
LLM_TENSOR_ALTUP_PROJ, // gemma3n
LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n
LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n
LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n
LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n
LLM_TENSOR_ALTUP_ROUTER, // gemma3n
LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n
LLM_TENSOR_LAUREL_L, // gemma3n
LLM_TENSOR_LAUREL_R, // gemma3n
LLM_TENSOR_LAUREL_POST_NORM, // gemma3n
LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_X,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_DT_NORM,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_B_NORM,
LLM_TENSOR_SSM_C_NORM,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
......@@ -365,6 +413,15 @@ enum llm_tensor {
LLM_TENSOR_POS_NET_ATTN_K,
LLM_TENSOR_POS_NET_ATTN_V,
LLM_TENSOR_POS_NET_ATTN_OUT,
LLM_TENSOR_SHORTCONV_CONV,
LLM_TENSOR_SHORTCONV_INPROJ,
LLM_TENSOR_SHORTCONV_OUTPROJ,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
enum llm_tensor_layer {
......@@ -438,3 +495,7 @@ const char * llm_arch_name(llm_arch arch);
llm_arch llm_arch_from_string(const std::string & name);
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
bool llm_arch_is_recurrent(const llm_arch & arch);
bool llm_arch_is_hybrid (const llm_arch & arch);
bool llm_arch_is_diffusion(const llm_arch & arch);
#include "llama-batch.h"
#include "llama-impl.h"
#include "llama-vocab.h"
#include "llama-memory.h"
#include <cassert>
#include <cstring>
#include <algorithm>
#include <sstream>
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
// clear empty sequences
// the previous ubatch is assumed to be gone,
// so nothing should refer to values in these sequences anymore.
for (size_t i = seq.size(); i-- > 0;) {
if (seq[i].length == 0) {
seq.pop_back();
} else {
break;
}
llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
seq_pos.resize(LLAMA_MAX_SEQ);
seq_cpl.resize(LLAMA_MAX_SEQ);
for (auto & cur : seq_cpl) {
cur.resize(LLAMA_MAX_SEQ);
}
ubatch_token.resize(!has_embd ? n_ubatch : 0);
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
ubatch_pos.resize(n_ubatch);
ubatch_n_seq_id.resize(n_ubatch);
ubatch_seq_id.resize(n_ubatch);
ubatch_output.resize(n_ubatch);
llama_ubatch ubatch = {
/*equal_seqs =*/ true,
/*n_tokens =*/ 0,
/*n_seq_tokens =*/ 0,
/*n_seqs =*/ 0,
/*token =*/ !has_embd ? ubatch_token.data() : nullptr,
/*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
/*pos =*/ ubatch_pos.data(),
/*n_seq_id =*/ ubatch_n_seq_id.data(),
/*seq_id =*/ ubatch_seq_id.data(),
/*output =*/ ubatch_output.data(),
};
return ubatch;
seq_idx.resize(LLAMA_MAX_SEQ, -1);
}
void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
GGML_ASSERT(batch != nullptr);
GGML_ASSERT(length <= seq.length);
// Can only add sequences of equal lengths to a batch,
// otherwise it isn't clear to which sequence a token belongs
GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
// NOTE: loops are separated for cache-friendliness
if (batch->token) {
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
bool llama_batch_allocr::init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all) {
clear();
batch = batch_inp;
this->vocab = &vocab;
GGML_ASSERT(batch.n_tokens > 0);
//
// validate input batch
//
if (n_seq_max > LLAMA_MAX_SEQ) {
LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
return false;
}
if (batch.token) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return false;
}
} else {
// simple split
ubatch.token = batch->token + seq.offset;
}
} else {
ubatch.token = nullptr;
}
if (batch->embd) {
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
memcpy(
ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
batch->embd + (n_embd * ids[seq.offset + i]),
n_embd * sizeof(float)
);
}
if (batch.seq_id) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
return false;
}
}
} else {
// simple split
ubatch.embd = batch->embd + (n_embd * seq.offset);
}
} else {
ubatch.embd = nullptr;
}
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
//
// auto-generate missing fields
//
if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
n_seq_id[i] = seq_id_0.size();
}
} else {
// simple split
ubatch.pos = batch->pos + seq.offset;
batch.n_seq_id = n_seq_id.data();
}
if (ubatch.equal_seqs) {
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
if (seq.seq_id) {
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
if (!batch.seq_id) {
seq_id.resize(batch.n_tokens + 1);
seq_id[batch.n_tokens] = NULL;
for (int32_t i = 0; i < batch.n_tokens; i++) {
seq_id[i] = seq_id_0.data();
}
} else {
// simple split
if (batch->n_seq_id) {
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
batch.seq_id = seq_id.data();
}
if (!batch.pos) {
pos.resize(batch.n_tokens);
// initialize the starting position for each sequence based on the positions in the memory
llama_pos p0[LLAMA_MAX_SEQ];
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (!memory) {
// if no memory -> start from 0
p0[s] = 0;
} else {
p0[s] = memory->seq_pos_max(s) + 1;
}
}
if (batch->seq_id) {
ubatch.seq_id = batch->seq_id + seq.offset;
for (int32_t i = 0; i < batch.n_tokens; i++) {
const llama_seq_id seq_id = batch.seq_id[i][0];
pos[i] = p0[seq_id];
// update the starting position for all sequences that are assigned to this token
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
const llama_seq_id seq_id = batch.seq_id[i][s];
p0[seq_id] = pos[i] + 1;
}
}
batch.pos = pos.data();
}
if (logits_all) {
for (size_t i = 0; i < length; ++i) {
ubatch.output[ubatch.n_tokens + i] = 1;
out_ids.push_back(ids[seq.offset + i]);
if (!batch.logits) {
if (output_all) {
// return the output for all tokens
output.resize(batch.n_tokens, true);
} else {
// return the output only for the last token
output.resize(batch.n_tokens, false);
output[output.size() - 1] = true;
}
} else if (batch->logits) {
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
size_t id = ids[seq.offset + i];
int8_t is_output = batch->logits[id];
ubatch.output[ubatch.n_tokens + i] = is_output;
if (is_output) { out_ids.push_back(id); }
batch.logits = output.data();
} else if (output_all) {
bool warn = false;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.logits[i] == 0) {
warn = true;
}
} else {
// simple split
ubatch.output = batch->logits + seq.offset;
for (size_t i = 0; i < length; ++i) {
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
}
if (warn) {
LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
output.resize(batch.n_tokens, true);
batch.logits = output.data();
}
}
//
// compute stats
//
this->n_embd = n_embd;
this->n_seq_max = n_seq_max;
// count the outputs in this batch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
n_outputs += batch.logits[i] != 0;
}
has_cpl = false;
// determine coupled sequences
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
for (int32_t i = 0; i < batch.n_tokens; ++i) {
const llama_seq_id s0 = batch.seq_id[i][0];
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
const llama_seq_id s1 = batch.seq_id[i][s];
seq_pos[s1].insert(batch.pos[i]);
if (s > 0) {
// mark that sequence s1 is coupled to s0
seq_cpl[s1][s0] = true;
// note: tracking the other way around is not necessary for now
//seq_cpl[s0][s1] = true;
has_cpl = true;
}
}
} else {
// only get last output
for (size_t i = 0; i < length; ++i) {
size_t id = ids[seq.offset + i];
int8_t is_last = id == ids.size() - 1;
ubatch.output[ubatch.n_tokens + i] = is_last;
if (is_last) { out_ids.push_back(id); }
}
}
if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
}
ubatch.n_tokens += length;
ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
seq.offset += length;
seq.length -= length;
n_tokens -= length;
GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
}
}
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
ubatch.equal_seqs = false;
if (!seq.empty()) {
llama_sbatch_seq & s = seq[0];
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
add_seq_to_ubatch(ubatch, s, length);
}
return ubatch;
}
// precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
{
seq_set_t seq_set_unq;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
seq_set_t cur;
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
const llama_seq_id seq_id = batch.seq_id[i][s];
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
if (!seq.empty()) {
size_t length = 0;
size_t n_tokens_in_ubatch = 0;
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
// smallest first, because it's easier to split this way;
// starting from the end to pop in constant time.
for (size_t i = seq.size(); i-- > 0;) {
llama_sbatch_seq & s = seq[i];
GGML_ASSERT(s.length > 0);
if (length == 0) {
length = s.length < n_ubatch ? s.length : n_ubatch;
cur .set(seq_id);
seq_set_unq.set(seq_id);
}
seq_set.push_back(cur);
seq_set_map[cur].push_back(i);
}
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_set_unq.test(s)) {
seq_idx[s] = seq_id_unq.size();
seq_id_unq.push_back(s);
}
add_seq_to_ubatch(ubatch, s, length);
n_tokens_in_ubatch += length;
// shared prompts can't be mixed with any of their sequences,
// so it's safer to compute them in their own ubatch
if (s.n_seq_id > 1) { break; }
// stop when there isn't enough space for another sequence
if (length + n_tokens_in_ubatch > n_ubatch) { break; }
}
}
return ubatch;
}
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
if (!seq.empty()) {
llama_sbatch_seq & s = seq[seq.size() - 1];
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
add_seq_to_ubatch(ubatch, s, length);
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
llama_ubatch ubatch {
/*.b_equal_seqs =*/ false,
/*.n_tokens =*/ (uint32_t) batch.n_tokens,
/*.n_seq_tokens =*/ (uint32_t) 1,
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
/*.token =*/ batch.token,
/*.embd =*/ batch.embd,
/*.pos =*/ batch.pos,
/*.n_seq_id =*/ batch.n_seq_id,
/*.seq_id =*/ batch.seq_id,
/*.seq_id_unq =*/ this->seq_id_unq.data(),
/*.seq_idx =*/ this->seq_idx.data(),
/*.output =*/ batch.logits,
/*.data =*/ {},
};
ubatch_print(ubatch, debug);
LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
if (seq_pos[s0].empty()) {
continue;
}
std::stringstream ss;
for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
if (seq_cpl[s0][s1]) {
ss << s1 << " ";
}
}
LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
__func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
}
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
}
return ubatch;
}
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch;
this->n_embd = n_embd;
this->logits_all = logits_all;
//
// consistency checks
//
n_tokens = batch.n_tokens;
ids.resize(n_tokens);
out_ids.clear();
// TODO: reserve out_ids and seq
for (size_t i = 0; i < n_tokens; ++i) {
ids[i] = i;
}
if (simple_split) {
seq.resize(1);
llama_sbatch_seq & s = seq[0];
s.n_seq_id = 0;
s.seq_id = nullptr;
s.offset = 0;
s.length = n_tokens;
return;
}
std::sort(ids.begin(), ids.end(),
[&batch](size_t a, size_t b) {
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
// sort by seq_id, then by pos
if (n_seq_a == n_seq_b) {
if (batch.seq_id) {
for (int32_t i = 0; i < n_seq_a; ++i) {
llama_seq_id seq_id_a = batch.seq_id[a][i];
llama_seq_id seq_id_b = batch.seq_id[b][i];
// smaller seq_ids go first
if (seq_id_a != seq_id_b) {
return seq_id_a < seq_id_b;
}
}
}
// when all else is equal, sort by pos
if (batch.pos) {
return batch.pos[a] < batch.pos[b];
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
if (p0 >= 0) {
bool ok = true;
if (batch.token) {
if (seq_pos_min(s) != p0 + 1) {
ok = false;
}
} else {
assert(batch.embd);
// for embeddings (typically used as vision input), we allow them to have repeating positions
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
ok = false;
}
}
if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));
return false;
}
}
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
return false;
}
}
if (memory) {
for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
if (seq_cpl[s0][s1]) {
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
return false;
}
// no pos, sort by id
return a < b;
}
// shared prompts go first
return n_seq_a > n_seq_b;
}
);
// init seq
llama_sbatch_seq * last_seq = nullptr;
for (size_t i = 0; i < n_tokens; ++i) {
    const size_t bi = ids[i];
    const int32_t n_seqs = batch.n_seq_id[bi];
    llama_seq_id * seq_ids = batch.seq_id[bi];
    if (last_seq != nullptr) {
        bool same = n_seqs == last_seq->n_seq_id;
        for (int32_t j = 0; same && j < n_seqs; ++j) {
            if (seq_ids[j] != last_seq->seq_id[j]) {
                same = false;
            }
        }
        if (same) {
            last_seq->length += 1;
            continue;
        }
    }
    llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
    seq.push_back(new_seq);
    last_seq = &seq.back();
}
// keep shared prompts first at the end, then sort by length descending.
std::sort(seq.begin(), seq.end(),
    [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
        if (a.n_seq_id == b.n_seq_id) {
            return a.length > b.length;
        }
        return a.n_seq_id < b.n_seq_id;
    }
);
//
// consistency checks
//
for (uint32_t s = 0; s < n_seq_max; ++s) {
    if (seq_pos[s].empty()) {
        continue;
    }
    const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
    if (p0 >= 0) {
        bool ok = true;
        if (batch.token) {
            if (seq_pos_min(s) != p0 + 1) {
                ok = false;
            }
        } else {
            assert(batch.embd);
            // for embeddings (typically used as vision input), we allow them to have repeating positions
            // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
            if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
                ok = false;
            }
        }
        if (!ok) {
            LLAMA_LOG_ERROR(
                "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
                " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
                " it is required that the sequence positions remain consecutive: Y = X + 1\n",
                __func__, s, s, p0, s, seq_pos_min(s));
            return false;
        }
    }
    if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
        LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
        return false;
    }
}
if (memory) {
    for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
        for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
            if (seq_cpl[s0][s1]) {
                if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                    memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
                    LLAMA_LOG_ERROR("%s: sequence %d is coupled to sequence %d in the input batch, but they have diverged\n", __func__, s0, s1);
                    return false;
                }
            }
        }
    }
}
// disallow partial sequence sub-sets:
//
//                  invalid:          x
//            i: 0 1 2 ...
// ---------------------------------------
// seq_id[i][0]: 0 0 1
// seq_id[i][1]: 1 1 2
// seq_id[i][2]: 2
//
// disallow decreasing sequence positions:
//
//                  invalid:                  x
//            i: 0 1 2 3 4 5 6 ...
// ---------------------------------------
//       pos[i]: 4 5 0 1 6 2 3
// seq_id[i][0]: 0 0 1 1 0 1 0
//
{
    seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        cur_seq_set[s].set();
    }
    llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        cur_seq_pos[s] = -1;
    }
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        const llama_pos pos = batch.pos[i];
        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
            const llama_seq_id seq_id = batch.seq_id[i][s];
            cur_seq_set[seq_id] &= seq_set[i];
            if (cur_seq_set[seq_id].none()) {
                LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
                return false;
            }
            if (pos < cur_seq_pos[seq_id]) {
                LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
                return false;
            }
            cur_seq_pos[seq_id] = pos;
        }
    }
}
split_reset();
return true;
}
llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
const uint32_t n_tokens = n_seq_tokens*n_seqs;
clear();
split_reset();
auto udata = std::make_shared<llama_ubatch::data_t>();
udata->token .resize(n_tokens);
udata->embd .clear();
udata->pos .resize(n_tokens);
udata->n_seq_id .resize(n_tokens);
udata->seq_id .resize(n_tokens);
udata->seq_id_unq.resize(0);
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
udata->output .resize(n_tokens);
for (uint32_t s = 0; s < n_seqs; ++s) {
udata->seq_idx[s] = s;
udata->seq_id_unq.push_back(s);
}
llama_ubatch res {
/*.b_equal_seqs =*/ true,
/*.n_tokens =*/ n_tokens,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ n_seqs,
/*.token =*/ udata->token.data(),
/*.embd =*/ nullptr,
/*.pos =*/ udata->pos.data(),
/*.n_seq_id =*/ udata->n_seq_id.data(),
/*.seq_id =*/ udata->seq_id.data(),
/*.seq_id_unq =*/ udata->seq_id_unq.data(),
/*.seq_idx =*/ udata->seq_idx.data(),
/*.output =*/ udata->output.data(),
/*.data =*/ std::move(udata),
};
return res;
}
const llama_batch & llama_batch_allocr::get_batch() const {
return batch;
}
uint32_t llama_batch_allocr::get_n_tokens() const {
return batch.n_tokens;
}
uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs;
}
uint32_t llama_batch_allocr::get_n_used() const {
return n_used;
}
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
return out_ids;
}
llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
}
llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
}
void llama_batch_allocr::split_reset() {
out_ids.clear();
n_used = 0;
used.clear();
used.resize(get_n_tokens(), false);
}
llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
// find the first unused token
uint32_t cur_idx = 0;
while (cur_idx < used.size() && used[cur_idx]) {
++cur_idx;
}
// we are done
if (cur_idx >= used.size()) {
return {};
}
std::vector<int32_t> idxs;
while (true) {
idxs.push_back(cur_idx);
used[cur_idx] = true;
++n_used;
++cur_idx;
if (cur_idx >= used.size()) {
break;
}
if (idxs.size() >= n_ubatch) {
break;
}
}
return ubatch_add(idxs, idxs.size(), false);
}
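// Illustrative sketch only (not part of the vendored sources): how a caller could drain a
// sanitized batch through split_simple() above; the names `balloc` and `n_ubatch` are assumed.
#if 0
void example_drain_simple(llama_batch_allocr & balloc, uint32_t n_ubatch) {
    balloc.split_reset();
    while (true) {
        llama_ubatch ub = balloc.split_simple(n_ubatch);
        if (ub.n_tokens == 0) {
            break; // the entire batch has been consumed
        }
        // ... build and compute the graph for ub ...
    }
}
#endif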
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
if (sequential && has_cpl) {
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
return {};
}
std::vector<seq_set_t> cur_seq_set;
llama_seq_id last_seq_id = -1;
// determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (used[i]) {
continue;
}
bool add = true;
for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
// no overlap with existing sequence sets:
if (!(cur_seq_set[s] & seq_set[i]).none()) {
add = false;
break;
}
}
// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}
if (add) {
cur_seq_set.push_back(seq_set[i]);
last_seq_id = batch.seq_id[i][0];
if (cur_seq_set.size() > n_ubatch) {
break;
}
}
    }
    const uint32_t n_seqs = cur_seq_set.size();
// we are done
if (n_seqs == 0) {
return {};
}
// the current batch index of each sequence set
std::vector<int32_t> cur_idx(n_seqs, 0);
for (uint32_t s = 0; s < n_seqs; ++s) {
while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
++cur_idx[s];
}
}
// the list of batch indices for each sequence set
// at the end we will concat these to get the final ubatch
std::vector<idx_vec_t> idxs_per_seq(n_seqs);
while (true) {
// we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
// if we haven't reached n_ubatch
bool can_expand = true;
for (uint32_t s = 0; s < n_seqs; ++s) {
if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
can_expand = false;
break;
}
}
if (!can_expand) {
break;
}
for (uint32_t s = 0; s < n_seqs; ++s) {
const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
idxs_per_seq[s].push_back(idx);
used[idx] = true;
++n_used;
++cur_idx[s];
}
if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
break;
}
}
// concat the per-sequence-set lists
std::vector<int32_t> idxs;
for (uint32_t s = 0; s < n_seqs; ++s) {
idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
}
return ubatch_add(idxs, n_seqs, true);
}
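// Illustrative sketch only (not part of the vendored sources): a worked example of the
// expansion loop above, assuming an allocator `balloc` holding three non-overlapping
// sequence sets with 5, 7 and 9 unused tokens.
#if 0
// expansion stops once (len + 1)*n_seqs > n_ubatch, i.e. after 4 tokens per set:
llama_ubatch ub = balloc.split_equal(/*n_ubatch=*/12, /*sequential=*/false);
assert(ub.n_seqs       == 3);
assert(ub.n_seq_tokens == 4);
assert(ub.n_tokens     == 12);
#endif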
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
    batch = in_batch;
    GGML_ASSERT(batch.n_tokens > 0);
    if (!batch.pos) {
        pos.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            pos[i] = i + p0;
        }
        batch.pos = pos.data();
    }
    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }
    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = NULL;
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }
    if (!batch.logits) {
        logits.resize(batch.n_tokens);
        logits[logits.size() - 1] = true;
        batch.logits = logits.data();
    }
}
llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
    // find the first unused token
    uint32_t cur_idx = 0;
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }
    // we are done
    if (cur_idx >= used.size()) {
        return {};
    }
    // this is the starting sequence set
    // we allow adding tokens only if their sequence set is a subset of the current sequence set
    auto cur_seq_set = seq_set[cur_idx];
    std::vector<int32_t> idxs;
    while (true) {
        idxs.push_back(cur_idx);
        used[cur_idx] = true;
        ++n_used;
        if (idxs.size() >= n_ubatch) {
            break;
        }
        do {
            ++cur_idx;
        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
        if (cur_idx == get_n_tokens()) {
            break;
        }
        cur_seq_set = seq_set[cur_idx];
    }
    return ubatch_add(idxs, 1, true);
}
void llama_batch_allocr::clear() {
n_outputs = 0;
batch = {};
pos .clear();
n_seq_id .clear();
seq_id .clear();
seq_id_unq.clear();
output .clear();
for (auto & cur : seq_pos) {
cur.clear();
}
for (auto & cur : seq_cpl) {
std::fill(cur.begin(), cur.end(), false);
}
seq_set.clear();
seq_set_map.clear();
std::fill(seq_idx.begin(), seq_idx.end(), -1);
}
llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
const uint32_t n_tokens = idxs.size();
assert(n_tokens%n_seqs == 0);
auto udata = std::make_shared<llama_ubatch::data_t>();
const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
udata->token .resize(n_tokens);
udata->embd .resize(n_embd_all);
udata->pos .resize(n_pos_all);
udata->n_seq_id .resize(n_tokens);
udata->seq_id .resize(n_tokens);
udata->seq_id_unq.resize(0);
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
udata->output .resize(n_tokens);
seq_set_t seq_set_unq;
for (size_t i = 0; i < idxs.size(); ++i) {
if (batch.token) {
udata->token[i] = batch.token[idxs[i]];
}
if (batch.embd) {
memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
}
for (int j = 0; j < n_pos_cur; ++j) {
udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
}
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
udata->seq_id[i] = batch.seq_id[idxs[i]];
udata->output[i] = batch.logits[idxs[i]];
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
seq_set_unq.set(udata->seq_id[i][s]);
}
if (udata->output[i]) {
out_ids.push_back(idxs[i]);
}
    }
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_set_unq.test(s)) {
            udata->seq_idx[s] = udata->seq_id_unq.size();
            udata->seq_id_unq.push_back(s);
        }
    }
llama_ubatch res {
/*.b_equal_seqs =*/ equal_seqs,
/*.n_tokens =*/ n_tokens,
/*.n_seq_tokens =*/ n_tokens/n_seqs,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
/*.token =*/ batch.token ? udata->token.data() : nullptr,
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
/*.pos =*/ udata->pos.data(),
/*.n_seq_id =*/ udata->n_seq_id.data(),
/*.seq_id =*/ udata->seq_id.data(),
/*.seq_id_unq =*/ udata->seq_id_unq.data(),
/*.seq_idx =*/ udata->seq_idx.data(),
/*.output =*/ udata->output.data(),
/*.data =*/ std::move(udata),
};
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
ubatch_print(res, debug);
}
return res;
}
void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs());
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
std::stringstream ss_seq_id_unq;
std::stringstream ss_seq_idx;
ss_seq_id_unq << "[ ";
ss_seq_idx << "[";
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
}
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (ubatch.seq_idx[s] >= 0) {
ss_seq_idx << ubatch.seq_idx[s]%10;
} else {
ss_seq_idx << ".";
}
}
ss_seq_id_unq << "]";
ss_seq_idx << "]";
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
if (debug > 1) {
int seq_id_max = 0;
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
                    seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
                }
}
++seq_id_max;
LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
std::vector<int8_t> seq_id(seq_id_max);
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
seq_id[ubatch.seq_id[i][s]] = 1;
}
std::stringstream ss;
for (int s = 0; s < seq_id_max; ++s) {
if (seq_id[s]) {
ss << s%10;
} else {
ss << ".";
}
}
if (ubatch.token) {
LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
__func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
} else {
LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
__func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
}
}
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
}
}
}
......@@ -317,25 +820,25 @@ struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens) {
return {
        /*n_tokens       =*/ n_tokens,
        /*tokens         =*/ tokens,
        /*embd           =*/ nullptr,
        /*pos            =*/ nullptr,
        /*n_seq_id       =*/ nullptr,
        /*seq_id         =*/ nullptr,
        /*logits         =*/ nullptr,
};
}
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
llama_batch batch = {
        /*n_tokens       =*/ 0,
        /*tokens         =*/ nullptr,
        /*embd           =*/ nullptr,
        /*pos            =*/ nullptr,
        /*n_seq_id       =*/ nullptr,
        /*seq_id         =*/ nullptr,
        /*logits         =*/ nullptr,
};
if (embd) {
......
......@@ -2,88 +2,159 @@
#include "llama.h"
#include "llama-cparams.h"
#include <array>
#include <vector>
#include <set>
#include <bitset>
#include <memory>
#include <unordered_map>
// very similar to llama_batch,
// but has more metadata about sequences
// keep this struct lightweight
struct llama_ubatch {
bool equal_seqs;
// TODO: whole_seqs for embeddings?
bool equal_seqs() const {
return b_equal_seqs != 0;
}
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
uint32_t n_seq_tokens; // tokens per sequence
uint32_t n_seqs;
uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
// otherwise address sanitizer complains
// TODO: whole_seqs for embeddings?
llama_token * token; // [n_tokens]
float * embd; // [n_embd, n_tokens]
llama_pos * pos; // [n_tokens]
int32_t * n_seq_id; // [n_seqs]
llama_seq_id ** seq_id; // [n_seqs]
int8_t * output; // [n_tokens]
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
uint32_t n_seq_tokens; // tokens per sequence set
uint32_t n_seqs; // sequence sets in the ubatch
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
// seq_id_unq: unique sequence ids in the ubatch
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
// used for extracting sequence pooled embeddings
// // size | idx | val
llama_token * token; // [n_tokens] | i | id, token
float * embd; // [n_embd, n_tokens] | i | embd
llama_pos * pos; // [n_tokens] | i | pos
int32_t * n_seq_id; // [n_tokens] | i | -
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
int8_t * output; // [n_tokens] | i | -
struct data_t {
std::vector<llama_token> token;
std::vector<float> embd;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
};
// the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
std::shared_ptr<data_t> data;
};
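// Illustrative sketch only (not part of the vendored sources): how seq_id_unq/seq_idx above
// are used to locate the pooled-embedding row of a sequence, mirroring the extraction code in
// llama_context later in this change; `ub` is an assumed llama_ubatch instance.
#if 0
for (uint32_t s = 0; s < ub.n_seqs_unq; ++s) {
    const llama_seq_id seq_id  = ub.seq_id_unq[s];
    const int32_t      seq_idx = ub.seq_idx[seq_id]; // index of seq_id in [0, n_seqs_unq)
    // read n_embd floats at offset n_embd*seq_idx to get the pooled embedding of seq_id
}
#endif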
struct llama_sbatch_seq {
int32_t n_seq_id;
// a helper for sanitizing, fulfilling and splitting a batch
class llama_batch_allocr {
public:
llama_batch_allocr(uint32_t n_pos_per_embd);
llama_seq_id * seq_id;
// sanitize and auto-gen missing data in the input batch
// memory is optional. if provided will be used to check for sequence continuity and to determine the positions
bool init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all);
size_t offset;
size_t length;
};
const llama_batch & get_batch() const;
// sequence-length-aware batch splitting
struct llama_sbatch {
// tokens left in this batch
size_t n_tokens;
uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const;
uint32_t get_n_used() const;
size_t n_embd;
// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
bool logits_all; // TODO: remove once lctx.logits_all is removed too
// min/max positions of each sequence in the current ubatch
llama_pos seq_pos_min(llama_seq_id seq_id) const;
llama_pos seq_pos_max(llama_seq_id seq_id) const;
// sorted indices into the batch
std::vector<int64_t> ids;
// batch indices of the output
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;
// call once before splitting the batch to reset the internal state
void split_reset();
const llama_batch * batch = nullptr;
// simple split, unknown number of sequence sets of unequal lengths
llama_ubatch split_simple(uint32_t n_ubatch);
// buffers for the ubatch
std::vector<llama_token> ubatch_token;
std::vector<float> ubatch_embd;
std::vector<llama_pos> ubatch_pos;
std::vector<int32_t> ubatch_n_seq_id;
std::vector<llama_seq_id *> ubatch_seq_id;
std::vector<int8_t> ubatch_output;
// make ubatches of equal-length sequences sets
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
// a helper method for creating a well-defined ubatch of tokens
// TODO: support embeddings if needed in the future
llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
// simple split, unknown number of sequences of unequal lengths
llama_ubatch split_simple(size_t n_ubatch);
private:
void clear();
// make batches of equal-length sequences
llama_ubatch split_equal(size_t n_ubatch);
// create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
// return llama_ubatch.n_tokens == 0 if the entire batch was consumed
llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
// sequence-wise split
llama_ubatch split_seq(size_t n_ubatch);
// for debugging, start with LLAMA_BATCH_DEBUG=2
void ubatch_print(const llama_ubatch & ubatch, int debug);
llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};
llama_batch batch;
// only for debugging purposes
const llama_vocab * vocab;
// TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
const uint32_t n_pos_per_embd;
// temporary allocate memory for the input batch if needed
struct llama_batch_allocr {
struct llama_batch batch;
uint32_t n_embd;
uint32_t n_seq_max;
uint32_t n_outputs;
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> logits;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>;
// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl = false;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
using idx_vec_t = std::vector<int32_t>;
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
// batch indices of the output
std::vector<int32_t> out_ids;
uint32_t n_used;
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;
// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
int debug;
};
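// Illustrative sketch only (not part of the vendored sources): rough lifecycle of the class
// above, as used by llama_context later in this change; `batch_inp`, `vocab`, `n_embd` and
// `n_seq_max` are assumed to come from the caller.
#if 0
llama_batch_allocr balloc(/*n_pos_per_embd=*/1);
if (!balloc.init(batch_inp, vocab, /*memory=*/nullptr, n_embd, n_seq_max, /*output_all=*/false)) {
    // the input batch failed the consistency checks (positions, sequence ids, ...)
}
balloc.split_reset();
for (llama_ubatch ub = balloc.split_simple(n_ubatch); ub.n_tokens > 0; ub = balloc.split_simple(n_ubatch)) {
    // ... process ub ...
}
#endif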
......@@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
{ "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
......@@ -64,6 +65,10 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
{ "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
......@@ -166,10 +171,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
} else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
} else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
if (tmpl_contains("[|tool|]")) {
return LLM_CHAT_TEMPLATE_EXAONE_4;
}
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
// EXAONE-3.0-7.8B-Instruct
return LLM_CHAT_TEMPLATE_EXAONE_3;
} else if (tmpl_contains("rwkv-world")) {
} else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
return LLM_CHAT_TEMPLATE_RWKV_WORLD;
} else if (tmpl_contains("<|start_of_role|>")) {
return LLM_CHAT_TEMPLATE_GRANITE;
......@@ -183,6 +191,16 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|extra_0|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
} else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
return LLM_CHAT_TEMPLATE_OPENAI_MOE;
} else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
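// Illustrative sketch only (not part of the vendored sources): the newly added templates are
// detected purely by marker substrings, so a Kimi-K2 Jinja template maps to
// LLM_CHAT_TEMPLATE_KIMI_K2 as long as no earlier marker in the detection chain matches.
#if 0
assert(llm_chat_detect_template("...<|im_assistant|>assistant<|im_middle|>...") == LLM_CHAT_TEMPLATE_KIMI_K2);
#endif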
......@@ -331,7 +349,7 @@ int32_t llm_chat_apply_template(
std::string role(message->role);
if (role == "system") {
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
system_prompt = trim(message->content);
system_prompt += trim(message->content);
continue;
}
// in gemma, "assistant" is "model"
......@@ -353,7 +371,7 @@ int32_t llm_chat_apply_template(
std::string role(message->role);
if (role == "system") {
// there is no system message support, we will merge it with user prompt
system_prompt = message->content;
system_prompt += message->content;
continue;
} else if (role == "user") {
ss << "Human: ";
......@@ -524,14 +542,35 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "[|assistant|]";
}
    } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
        // this template requires the model to have "\n\n" as EOT token
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "user") {
                ss << "User: " << message->content << "\n\nAssistant:";
            } else {
                ss << message->content << "\n\n";
            }
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
        for (auto message : chat) {
            std::string role(message->role);
            if (role == "system") {
                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
            } else if (role == "user") {
                ss << "[|user|]" << trim(message->content) << "\n";
            } else if (role == "assistant") {
                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
            } else if (role == "tool") {
                ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
            }
        }
        if (add_ass) {
            ss << "[|assistant|]";
        }
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
// this template requires the model to have "\n\n" as EOT token
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "system") {
ss << "System: " << trim(chat[i]->content) << "\n\n";
} else if (role == "user") {
ss << "User: " << trim(chat[i]->content) << "\n\n";
if (i == chat.size() - 1) {
ss << "Assistant:";
}
} else if (role == "assistant") {
ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
......@@ -586,8 +625,6 @@ int32_t llm_chat_apply_template(
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
// Yandex template ("\n\n" is defined as EOT token)
ss << "<s>";
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "user") {
......@@ -643,6 +680,78 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "Assistant:";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
// dots.llm1.inst (DOTS1)
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|system|>" << message->content << "<|endofsystem|>";
} else if (role == "user") {
ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
} else {
ss << "<|response|>" << message->content << "<|endofresponse|>";
}
}
if (add_ass) {
ss << "<|response|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
// tencent/Hunyuan-A13B-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
} else if (role == "assistant") {
ss << message->content << "<|eos|>";
} else {
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) {
// OpenAI MoE (based on Harmony chat template)
for (auto message : chat) {
std::string role(message->role);
ss << "<|start|>" << role << "<|message|>" << message->content;
ss << (role == "assistant" ? "<|return|>" : "<|end|>");
}
if (add_ass) {
ss << "<|start|>assistant";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) {
// tencent/Hunyuan-4B-Instruct
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (i == 0) {
if (role == "system") {
ss << chat[i]->content << "<|hy_place▁holder▁no▁3|>";
}
}
if (role == "assistant") {
ss << "<|hy_Assistant|>" << chat[i]->content << "<|hy_place▁holder▁no▁2|>";
} else if (role == "user") {
ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
// moonshotai/Kimi-K2-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|im_system|>system<|im_middle|>";
} else if (role == "user") {
ss << "<|im_user|>user<|im_middle|>";
} else if (role == "assistant") {
ss << "<|im_assistant|>assistant<|im_middle|>";
} else if (role == "tool") {
ss << "<|im_system|>tool<|im_middle|>";
}
ss << message->content << "<|im_end|>";
}
if (add_ass) {
ss << "<|im_assistant|>assistant<|im_middle|>";
}
} else {
// template not supported
return -1;
......
......@@ -35,6 +35,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
LLM_CHAT_TEMPLATE_EXAONE_4,
LLM_CHAT_TEMPLATE_RWKV_WORLD,
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
......@@ -43,6 +44,11 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_OPENAI_MOE,
LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
......
#include "llama-context.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
#include "llama-memory.h"
#include "llama-mmap.h"
#include "llama-model.h"
#include "llama-kv-cache.h"
#include <cinttypes>
#include <cstring>
#include <limits>
#include <stdexcept>
//
// llama_context
......@@ -17,7 +19,8 @@
llama_context::llama_context(
const llama_model & model,
llama_context_params params) :
model(model) {
model(model),
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
t_start_us = model.t_start_us;
......@@ -25,7 +28,11 @@ llama_context::llama_context(
const auto & hparams = model.hparams;
    cparams.n_seq_max = std::max(1u, params.n_seq_max);
if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
}
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
......@@ -91,9 +98,29 @@ llama_context::llama_context(
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
cparams.n_batch = GGML_KQ_MASK_PAD;
}
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.op_offload = params.op_offload;
cparams.kv_unified = params.kv_unified;
{
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
if (!supports_set_rows && !cparams.kv_unified) {
LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
cparams.kv_unified = true;
}
}
{
const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
if (graph_reuse_disable) {
LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__);
}
}
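    // Illustrative note only (not part of the vendored sources): the two environment toggles
    // read above can be set at process launch, e.g.
    //
    //   LLAMA_SET_ROWS=0            -> disable ggml_set_rows(), which forces a unified KV cache
    //   LLAMA_GRAPH_REUSE_DISABLE=1 -> rebuild the compute graph for every ubatch
    //
    // both are read once here in the llama_context constructor.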
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
......@@ -104,6 +131,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
......@@ -117,6 +145,11 @@ llama_context::llama_context(
__func__, n_ctx_per_seq, hparams.n_ctx_train);
}
if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
__func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
}
if (!hparams.vocab_only) {
// GPU backends
for (auto * dev : model.devices) {
......@@ -176,8 +209,9 @@ llama_context::llama_context(
// init the memory module
if (!hparams.vocab_only) {
llama_memory_params params_mem = {
            /*.type_k   =*/ params.type_k,
            /*.type_v   =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
};
memory.reset(model.create_memory(params_mem, cparams));
......@@ -213,8 +247,8 @@ llama_context::llama_context(
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
// buffer used to store the computation graph and the tensor meta data
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
gf_res_prev.reset(new llm_graph_result(max_nodes));
gf_res_reserve.reset(new llm_graph_result(max_nodes));
// TODO: move these checks to ggml_backend_sched
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
......@@ -253,15 +287,9 @@ llama_context::llama_context(
// reserve worst-case graph
if (!hparams.vocab_only && memory) {
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
// restore later
// TODO: something cleaner
const auto n_outputs_save = n_outputs;
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
int n_splits_pp = -1;
......@@ -271,25 +299,18 @@ llama_context::llama_context(
int n_nodes_tg = -1;
// simulate full KV cache
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->set_full();
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize KV cache");
}
cross.v_embd.clear();
// reserve pp graph first so that buffers are only allocated once
// reserve pp (prompt processing) graph first so that buffers are only allocated once
{
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
// max number of outputs
n_outputs = ubatch_pp.n_tokens;
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
auto * gf = graph_init();
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
......@@ -297,18 +318,10 @@ llama_context::llama_context(
n_nodes_pp = ggml_graph_n_nodes(gf);
}
// reserve with tg graph to get the number of splits and nodes
// reserve with tg (token generation) graph to get the number of splits and nodes
{
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
n_outputs = ubatch_tg.n_tokens;
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
auto * gf = graph_init();
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
......@@ -318,22 +331,16 @@ llama_context::llama_context(
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
n_outputs = ubatch_pp.n_tokens;
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
auto * gf = graph_init();
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
        // TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
//
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
//
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
}
n_outputs = n_outputs_save;
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
ggml_backend_t backend = backend_ptrs[i];
ggml_backend_buffer_type_t buft = backend_buft[i];
......@@ -405,10 +412,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
return sched.get();
}
ggml_context * llama_context::get_ctx_compute() const {
return ctx_compute.get();
}
uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
......@@ -437,46 +440,76 @@ uint32_t llama_context::n_threads_batch() const {
return cparams.n_threads_batch;
}
llama_kv_cache * llama_context::get_kv_self() {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
return kv_self;
llama_memory_t llama_context::get_memory() const {
return memory.get();
}
const llama_kv_cache * llama_context::get_kv_self() const {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
return kv_self;
}
// deprecated
void llama_context::kv_self_defrag_sched() {
if (!memory) {
return;
}
void llama_context::kv_self_update() {
bool need_reserve = false;
memory_force_optimize = true;
}
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
// deprecated
bool llama_context::kv_self_update(bool optimize) {
if (!memory) {
return false;
}
need_reserve = kv_self->update(*this);
{
// TODO: remove in the future
optimize |= memory_force_optimize;
memory_force_optimize = false;
// reserve a worst case graph if needed
if (need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
const auto mctx = memory->init_update(this, optimize);
switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
// noop
} break;
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
// no updates need to be performed
return false;
}
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{
LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
return false;
}
}
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// reset the previous graph result to make sure that it won't be reused
// TODO: change the mctx->apply() to return information if a graph reserve is needed
// reset the graph result only if the memory module did reset the scheduler
gf_res_prev->reset();
// simulate full KV cache
kv_self->set_full();
if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
}
}
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
// if the memory module did any computation, we have to reserve a new worst-case graph
{
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory context");
}
auto * gf = graph_init();
graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(sched.get());
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
}
}
return true;
}
enum llama_pooling_type llama_context::pooling_type() const {
......@@ -484,11 +517,15 @@ enum llama_pooling_type llama_context::pooling_type() const {
}
float * llama_context::get_logits() {
output_reorder();
return logits;
}
float * llama_context::get_logits_ith(int32_t i) {
int32_t j = -1;
int64_t j = -1;
output_reorder();
try {
if (logits == nullptr) {
......@@ -511,7 +548,7 @@ float * llama_context::get_logits_ith(int32_t i) {
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}
return logits + j*model.vocab.n_tokens();
......@@ -526,11 +563,15 @@ float * llama_context::get_logits_ith(int32_t i) {
}
float * llama_context::get_embeddings() {
output_reorder();
return embd;
}
float * llama_context::get_embeddings_ith(int32_t i) {
int32_t j = -1;
int64_t j = -1;
output_reorder();
try {
if (embd == nullptr) {
......@@ -553,7 +594,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}
return embd + j*model.hparams.n_embd;
......@@ -670,66 +711,119 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}
int llama_context::encode(llama_batch & inp_batch) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
// temporary allocate memory for the input batch if needed
// note: during encode, we always pass the full sequence starting from pos = 0
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
auto * res = gf_res_prev.get();
auto * gf = res->get_gf();
const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens;
// the new graph parameters
// in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
const auto gparams = graph_params(res, ubatch, mctx, gtype);
const auto & hparams = model.hparams;
if (!graph_reuse_disable && res->can_reuse(gparams)) {
//LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
n_reused++;
} else {
res->reset();
if (batch.token) {
for (int32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
//const auto t_start_us = ggml_time_us();
gf = model.build_graph(gparams);
//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
}
}
// set the input data for the input tensors
{
//const auto t_start_us = ggml_time_us();
res->set_inputs(&ubatch);
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
}
const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
if (status != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
ret = status;
return nullptr;
}
ret = GGML_STATUS_SUCCESS;
return res;
}
int llama_context::encode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = model.vocab.n_tokens();
// note: during encode, we always pass the full sequence starting from pos = 0
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
const uint32_t n_tokens = balloc->get_n_tokens();
// [TAG_NO_CACHE_PAD]
// TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
}
// TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear();
n_queued_tokens += n_tokens;
const int64_t n_embd = hparams.n_embd;
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
// reserve output buffer
if (output_reserve(n_tokens) < n_tokens) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
return -2;
};
for (int32_t i = 0; i < n_tokens; ++i) {
for (uint32_t i = 0; i < n_tokens; ++i) {
output_ids[i] = i;
}
n_outputs = n_tokens;
//batch_manager->prepare(ubatch);
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
const auto causal_attn_org = cparams.causal_attn;
// always use non-causal attention for encoder graphs
......@@ -737,32 +831,34 @@ int llama_context::encode(llama_batch & inp_batch) {
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
cparams.causal_attn = false;
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
ggml_backend_sched_alloc_graph(sched.get(), gf);
res->set_inputs(&ubatch);
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
cparams.causal_attn = causal_attn_org;
const auto compute_status = graph_compute(gf, n_tokens > 1);
switch (compute_status) {
case GGML_STATUS_SUCCESS:
break;
case GGML_STATUS_ABORTED:
return 2;
case GGML_STATUS_ALLOC_FAILED:
return -2;
case GGML_STATUS_FAILED:
default:
return -3;
if (!res) {
switch (status) {
case GGML_STATUS_ABORTED: return 2;
case GGML_STATUS_ALLOC_FAILED: return -2;
case GGML_STATUS_FAILED: return -3;
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
}
}
auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
// extract logits
if (logits && t_logits) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);
ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
}
// extract embeddings
if (t_embd) {
if (embd && t_embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);
......@@ -781,31 +877,28 @@ int llama_context::encode(llama_batch & inp_batch) {
{
// extract sequence embeddings
auto & embd_seq_out = embd_seq;
embd_seq_out.clear();
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
for (int32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - a single float per sequence
// extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(1);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
const uint32_t n_cls_out = hparams.n_cls_out;
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
......@@ -815,9 +908,11 @@ int llama_context::encode(llama_batch & inp_batch) {
}
}
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
if (!supports_set_rows) {
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
}
// TODO: hacky solution
if (model.arch == LLM_ARCH_T5 && t_embd) {
......@@ -830,12 +925,16 @@ int llama_context::encode(llama_batch & inp_batch) {
cross.v_embd.resize(cross.n_embd*cross.n_enc);
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
const auto & batch = balloc->get_batch();
// remember the sequence ids used during the encoding - needed for cross attention later
cross.seq_ids_enc.resize(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
for (uint32_t i = 0; i < n_tokens; i++) {
cross.seq_ids_enc[i].clear();
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
llama_seq_id seq_id = ubatch.seq_id[i][s];
for (int s = 0; s < batch.n_seq_id[i]; s++) {
const llama_seq_id seq_id = batch.seq_id[i][s];
cross.seq_ids_enc[i].insert(seq_id);
}
}
......@@ -844,43 +943,42 @@ int llama_context::encode(llama_batch & inp_batch) {
return 0;
}
int llama_context::decode(llama_batch & inp_batch) {
int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (!memory) {
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
return encode(inp_batch);
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
return encode(batch_inp);
}
if (inp_batch.n_tokens == 0) {
if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const int32_t n_vocab = vocab.n_tokens();
const int64_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
// when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
llama_kv_cache_guard kv_guard(kv_self);
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
const uint32_t n_tokens_all = balloc->get_n_tokens();
const uint32_t n_outputs_all = balloc->get_n_outputs();
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
}
if (output_all) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
__func__, n_outputs_all, n_tokens_all);
return -1;
}
}
......@@ -893,49 +991,78 @@ int llama_context::decode(llama_batch & inp_batch) {
}
n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
// TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear();
output_swaps.clear();
bool did_optimize = false;
int64_t n_outputs_all = 0;
// handle any pending defrags/shifts
kv_self_update(false);
llama_memory_context_ptr mctx;
// count outputs
if (batch.logits) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
while (true) {
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mctx) {
return -2;
}
} else if (embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}
llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
} break;
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
return -2;
}
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
{
if (!did_optimize) {
did_optimize = true;
if (kv_self_update(true)) {
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
continue;
}
}
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
return 1;
}
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
return -2;
}
}
break;
}
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
return -2;
};
// handle any pending defrags/shifts
kv_self_update();
int64_t n_outputs_prev = 0;
while (sbatch.n_tokens > 0) {
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
do {
const auto & ubatch = mctx->get_ubatch();
// count the outputs in this u_batch
// count the outputs in this ubatch
{
int32_t n_outputs_new = 0;
if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
......@@ -945,38 +1072,37 @@ int llama_context::decode(llama_batch & inp_batch) {
n_outputs = n_outputs_new;
}
// find KV slot
if (!kv_self->find_slot(ubatch)) {
kv_self->defrag_sched(-1.0f);
kv_self->update(*this);
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
return 1;
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
llama_pos pos_min[LLAMA_MAX_SEQ];
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
}
}
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto & seq_id = ubatch.seq_id[i][0];
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
continue;
}
ggml_backend_sched_alloc_graph(sched.get(), gf);
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
res->set_inputs(&ubatch);
memory->seq_rm(s, pos_min[s], -1);
}
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
case GGML_STATUS_ALLOC_FAILED:
return -2;
case GGML_STATUS_FAILED:
default:
return -3;
switch (status) {
case GGML_STATUS_ABORTED: return 2;
case GGML_STATUS_ALLOC_FAILED: return -2;
case GGML_STATUS_FAILED: return -3;
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
}
}
......@@ -985,7 +1111,7 @@ int llama_context::decode(llama_batch & inp_batch) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
auto * t_logits = cparams.causal_attn ? res->get_logits() : nullptr;
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
if (t_embd && res->get_embd_pooled()) {
@@ -1032,27 +1158,27 @@ int llama_context::decode(llama_batch & inp_batch) {
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - a single float per sequence
// extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(1);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
const uint32_t n_cls_out = hparams.n_cls_out;
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1063,23 +1189,20 @@ int llama_context::decode(llama_batch & inp_batch) {
}
n_outputs_prev += n_outputs;
}
// finalize the batch processing
kv_guard.commit();
} while (mctx->next());
// set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all;
// set output mappings
{
if (n_outputs > 0) {
bool sorted_output = true;
auto & out_ids = sbatch.out_ids;
auto & out_ids = balloc->get_out_ids();
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
for (int64_t i = 0; i < n_outputs_all; ++i) {
for (int64_t i = 0; i < n_outputs; ++i) {
int64_t out_id = out_ids[i];
output_ids[out_id] = i;
if (out_id != i) {
@@ -1090,35 +1213,29 @@ int llama_context::decode(llama_batch & inp_batch) {
// make the outputs have the same order they had in the user-provided batch
// note: this is mostly relevant for recurrent models atm
if (!sorted_output) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
for (int32_t i = 0; i < n_outputs - 1; ++i) {
int32_t j_min = i;
for (int32_t j = i + 1; j < n_outputs; ++j) {
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
uint32_t j_min = i;
for (uint32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
if (j_min == i) { continue; }
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
}
if (j_min == i) {
continue;
}
std::swap(out_ids[i], out_ids[j_min]);
// remember the swaps and apply them lazily upon logits/embeddings access
output_swaps.push_back({ i, j_min });
}
std::fill(output_ids.begin(), output_ids.end(), -1);
for (int32_t i = 0; i < n_outputs; ++i) {
for (uint32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
}
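// NOTE: standalone sketch, not part of the upstream change. It spells out why the loop
// above uses selection sort: it needs at most n-1 swaps, and every swap recorded here is
// later paid for with a full logits/embeddings row exchange, so the number of swaps (not
// comparisons) is what matters. The helper name is hypothetical; disabled with #if 0.
#if 0
#include <cstdint>
#include <utility>
#include <vector>

// returns the (i, j_min) pairs that the loop above would push into output_swaps
static std::vector<std::pair<uint32_t, uint32_t>> selection_sort_swaps(std::vector<int64_t> ids) {
    std::vector<std::pair<uint32_t, uint32_t>> swaps;
    for (uint32_t i = 0; i + 1 < (uint32_t) ids.size(); ++i) {
        uint32_t j_min = i;
        for (uint32_t j = i + 1; j < (uint32_t) ids.size(); ++j) {
            if (ids[j] < ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min != i) {
            std::swap(ids[i], ids[j_min]);
            swaps.push_back({i, j_min});
        }
    }
    return swaps;
}
#endif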
@@ -1127,15 +1244,12 @@ int llama_context::decode(llama_batch & inp_batch) {
// wait for the computation to finish (automatically done when obtaining the model output)
//synchronize();
// decide if we need to defrag the kv cache
if (cparams.defrag_thold > 0.0f) {
kv_self->defrag_sched(cparams.defrag_thold);
if (!supports_set_rows) {
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
}
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
return 0;
}
@@ -1143,7 +1257,7 @@ int llama_context::decode(llama_batch & inp_batch) {
// output
//
int32_t llama_context::output_reserve(int32_t n_outputs) {
uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
@@ -1153,9 +1267,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
bool has_logits = cparams.causal_attn;
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
bool has_logits = true;
bool has_embd = cparams.embeddings;
// TODO: hacky enc-dec support
if (model.arch == LLM_ARCH_T5) {
@@ -1209,53 +1322,111 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);
this->n_outputs = 0;
this->n_outputs_max = n_outputs_max;
this->n_outputs = 0;
return n_outputs_max;
}
void llama_context::output_reorder() {
const uint64_t n_vocab = model.vocab.n_tokens();
const uint64_t n_embd = model.hparams.n_embd;
for (size_t s = 0; s < output_swaps.size(); ++s) {
const uint64_t i0 = output_swaps[s].i0;
const uint64_t i1 = output_swaps[s].i1;
if (logits_size > 0) {
for (uint64_t k = 0; k < n_vocab; k++) {
std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint64_t k = 0; k < n_embd; k++) {
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
}
}
}
output_swaps.clear();
}
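// NOTE: minimal sketch, not part of the upstream change. It shows the record-then-apply
// idea behind output_reorder(): decode() only records row swaps, and the flat
// logits/embeddings buffers are permuted lazily here, right before they are read.
// toy_swap and apply_row_swaps are hypothetical names; disabled with #if 0.
#if 0
#include <cstddef>
#include <utility>
#include <vector>

struct toy_swap { size_t i0, i1; };

static void apply_row_swaps(std::vector<float> & rows, size_t row_width, const std::vector<toy_swap> & swaps) {
    for (const auto & s : swaps) {
        for (size_t k = 0; k < row_width; ++k) {
            std::swap(rows[s.i0*row_width + k], rows[s.i1*row_width + k]);
        }
    }
}
#endif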
//
// graph
//
int32_t llama_context::graph_max_nodes() const {
return std::max<int32_t>(65536, 5*model.n_tensors());
uint32_t llama_context::graph_max_nodes() const {
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
}
ggml_cgraph * llama_context::graph_init() {
ggml_init_params params = {
/*.mem_size =*/ buf_compute_meta.size(),
/*.mem_buffer =*/ buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
llm_graph_result * llama_context::get_gf_res_reserve() const {
return static_cast<llm_graph_result *>(gf_res_reserve.get());
}
ctx_compute.reset(ggml_init(params));
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
}
if (n_tokens % n_seqs != 0) {
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
n_outputs = std::min(n_outputs, n_tokens);
llm_graph_result_ptr llama_context::graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype) {
return model.build_graph(
{
/*.ctx =*/ ctx,
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.memory =*/ memory.get(),
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
}, gf, gtype);
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
}
ggml_backend_sched_reset(sched.get());
// when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
gf_res_prev->reset();
// store the n_outputs as it is, and restore it afterwards
// TODO: not sure if needed, might simplify in the future by removing this
const auto save_n_outputs = this->n_outputs;
this->n_outputs = n_outputs;
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
auto * res = gf_res_reserve.get();
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
res->reset();
auto * gf = model.build_graph(gparams);
this->n_outputs = save_n_outputs;
// initialize scheduler with the specified graph
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
return nullptr;
}
return gf;
}
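// NOTE: the rounding used above, ((n + (m - 1)) / m) * m, is plain integer
// round-up-to-a-multiple: e.g. n_tokens = 10 with n_seqs = 4 reserves a graph for 12
// tokens. The tiny self-check below is illustrative only (hypothetical helper, #if 0)
// and is not part of the upstream change.
#if 0
#include <cassert>
#include <cstdint>

static uint32_t round_up_to_multiple(uint32_t n, uint32_t m) {
    return ((n + (m - 1)) / m) * m;
}

static void round_up_self_check() {
    assert(round_up_to_multiple(10, 4) == 12);
    assert(round_up_to_multiple(12, 4) == 12);
    assert(round_up_to_multiple( 1, 3) ==  3);
}
#endif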
llm_graph_params llama_context::graph_params(
llm_graph_result * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const {
return {
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.gtype =*/ gtype,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
/*.res =*/ res,
};
}
ggml_status llama_context::graph_compute(
@@ -1660,14 +1831,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
std::vector<int32_t> w_output_pos;
GGML_ASSERT(n_outputs <= n_outputs_max);
w_output_pos.resize(n_outputs);
// build a more compact representation of the output ids
for (size_t i = 0; i < n_batch(); ++i) {
// map an output id to a position in the batch
int32_t pos = output_ids[i];
int64_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
@@ -1707,10 +1876,10 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
}
}
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_write(io);
if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
memory->state_write(io);
}
return io.n_bytes();
}
@@ -1796,9 +1965,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
if (memory) {
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_read(io);
memory->state_read(io);
}
return io.n_bytes();
@@ -1808,9 +1975,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
GGML_UNUSED(seq_id);
if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_write(io, seq_id);
memory->state_write(io, seq_id);
}
return io.n_bytes();
@@ -1820,9 +1985,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
GGML_UNUSED(seq_id);
if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_read(io, seq_id);
memory->state_read(io, seq_id);
}
return io.n_bytes();
@@ -1841,6 +2004,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
data.t_eval_ms = 1e-3 * t_eval_us;
data.n_p_eval = std::max(1, n_p_eval);
data.n_eval = std::max(1, n_eval);
data.n_reused = std::max(0, n_reused);
return data;
}
@@ -1849,6 +2013,7 @@ void llama_context::perf_reset() {
t_start_us = ggml_time_us();
t_eval_us = n_eval = 0;
t_p_eval_us = n_p_eval = 0;
n_reused = 0;
}
//
@@ -1927,10 +2092,7 @@ void llama_context::opt_epoch_iter(
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->clear();
llama_kv_cache_guard kv_guard(kv_self);
memory->clear(true);
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
batch.n_tokens = n_batch;
@@ -1942,42 +2104,49 @@ void llama_context::opt_epoch_iter(
batch.logits [pos_batch] = true;
}
const auto n_tokens_all = batch.n_tokens;
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return;
}
n_queued_tokens += n_tokens_all;
const uint32_t n_tokens_all = balloc->get_n_tokens();
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
n_queued_tokens += n_tokens_all;
embd_seq.clear();
int64_t n_outputs_all = n_tokens_all;
uint32_t n_outputs_all = n_tokens_all;
llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
}
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
GGML_ABORT("TODO: handle this error");
};
for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
uint32_t pos_batch = 0;
do {
const auto & ubatch = mctx->get_ubatch();
n_outputs = ubatch.n_tokens;
// TODO: not sure if this is needed
if (!kv_self->find_slot(ubatch)) {
kv_self->defrag_sched(-1.0f);
kv_self->update(*this);
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
GGML_ABORT("TODO: handle this error");
}
if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
break;
}
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
auto * res = gf_res_prev.get();
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
res->reset();
auto * gf = model.build_graph(gparams);
struct ggml_context * ctx_compute_opt;
{
@@ -1992,6 +2161,7 @@ void llama_context::opt_epoch_iter(
}
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
ggml_opt_alloc(opt_ctx, train);
res->set_inputs(&ubatch);
{
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2009,10 +2179,10 @@ void llama_context::opt_epoch_iter(
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
}
ggml_free(ctx_compute_opt);
}
}
kv_guard.commit();
pos_batch += ubatch.n_tokens;
} while (mctx->next());
}
}
void llama_context::opt_epoch(
@@ -2097,6 +2267,8 @@ llama_context_params llama_context_default_params() {
/*.flash_attn =*/ false,
/*.no_perf =*/ true,
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.kv_unified =*/ false,
};
return result;
@@ -2171,12 +2343,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->get_model();
}
// deprecated
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
return ctx->get_kv_self();
return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
}
// deprecated
void llama_kv_self_update(llama_context * ctx) {
ctx->kv_self_update();
ctx->kv_self_update(false);
}
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2292,27 +2466,108 @@ int32_t llama_apply_adapter_cvec(
}
//
// kv cache view
// memory
//
llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
const auto * kv = ctx->get_kv_self();
if (kv == nullptr) {
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
return {};
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
return ctx->get_memory();
}
void llama_memory_clear(llama_memory_t mem, bool data) {
if (!mem) {
return;
}
mem->clear(data);
}
bool llama_memory_seq_rm(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
if (!mem) {
return true;
}
return llama_kv_cache_view_init(*kv, n_seq_max);
return mem->seq_rm(seq_id, p0, p1);
}
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
const auto * kv = ctx->get_kv_self();
if (kv == nullptr) {
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
void llama_memory_seq_cp(
llama_memory_t mem,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
if (!mem) {
return;
}
llama_kv_cache_view_update(view, kv);
mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_memory_seq_keep(
llama_memory_t mem,
llama_seq_id seq_id) {
if (!mem) {
return;
}
mem->seq_keep(seq_id);
}
void llama_memory_seq_add(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
if (!mem) {
return;
}
mem->seq_add(seq_id, p0, p1, delta);
}
void llama_memory_seq_div(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
if (!mem) {
return;
}
mem->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_memory_seq_pos_min(
llama_memory_t mem,
llama_seq_id seq_id) {
if (!mem) {
return -1;
}
return mem->seq_pos_min(seq_id);
}
llama_pos llama_memory_seq_pos_max(
llama_memory_t mem,
llama_seq_id seq_id) {
if (!mem) {
return -1;
}
return mem->seq_pos_max(seq_id);
}
bool llama_memory_can_shift(llama_memory_t mem) {
if (!mem) {
return false;
}
return mem->get_can_shift();
}
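// NOTE: hedged usage sketch for the llama_memory_* API above, not part of the upstream
// change. A caller fetches the memory handle from a context and then manipulates
// sequences through it instead of the deprecated llama_kv_self_* wrappers. The helper
// name is hypothetical and p1 = -1 is used in the usual "until the end" sense; disabled
// with #if 0.
#if 0
static void forget_prompt_prefix(llama_context * ctx, llama_seq_id seq, llama_pos keep_from) {
    llama_memory_t mem = llama_get_memory(ctx);
    if (!mem) {
        return; // the context has no memory module
    }
    // drop everything before keep_from for this sequence ...
    llama_memory_seq_rm(mem, seq, 0, keep_from);
    // ... and, if the backend supports shifting, slide the remainder back to position 0
    if (llama_memory_can_shift(mem)) {
        llama_memory_seq_add(mem, seq, keep_from, -1, -keep_from);
    }
}
#endif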
//
@@ -2320,203 +2575,161 @@ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view *
//
// deprecated
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
return llama_kv_self_n_tokens(ctx);
}
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self();
const auto * kv = llama_get_memory(ctx);
if (!kv) {
return 0;
}
return kv->get_n_tokens();
}
int32_t res = 0;
// deprecated
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
return llama_kv_self_used_cells(ctx);
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
const llama_pos p0 = kv->seq_pos_min(s);
const llama_pos p1 = kv->seq_pos_max(s);
if (p0 >= 0) {
res += (p1 - p0) + 1;
}
}
return res;
}
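// NOTE (worked example, not part of the upstream change): with sequence 0 covering
// positions [0, 9] and sequence 1 covering [5, 7], the span-summing fallback above
// reports (9 - 0 + 1) + (7 - 5 + 1) = 13. Because it sums per-sequence position spans,
// cells shared between sequences in a unified cache may be counted once per sequence.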
// deprecated
// note: this is the same as above - will be removed anyway, so it's ok
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self();
const auto * kv = llama_get_memory(ctx);
if (!kv) {
return 0;
}
return kv->get_used_cells();
}
int32_t res = 0;
// deprecated
void llama_kv_cache_clear(llama_context * ctx) {
llama_kv_self_clear(ctx);
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
const llama_pos p0 = kv->seq_pos_min(s);
const llama_pos p1 = kv->seq_pos_max(s);
if (p0 >= 0) {
res += (p1 - p0) + 1;
}
}
return res;
}
// deprecated
void llama_kv_self_clear(llama_context * ctx) {
auto * kv = ctx->get_kv_self();
auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
kv->clear();
llama_memory_clear(kv, true);
}
// deprecated
bool llama_kv_cache_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
}
bool llama_kv_self_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
auto * kv = ctx->get_kv_self();
auto * kv = llama_get_memory(ctx);
if (!kv) {
return true;
}
return kv->seq_rm(seq_id, p0, p1);
return llama_memory_seq_rm(kv, seq_id, p0, p1);
}
// deprecated
void llama_kv_cache_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_self_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
auto * kv = ctx->get_kv_self();
auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
}
// deprecated
void llama_kv_cache_seq_keep(
llama_context * ctx,
llama_seq_id seq_id) {
llama_kv_self_seq_keep(ctx, seq_id);
}
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
auto * kv = ctx->get_kv_self();
auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
kv->seq_keep(seq_id);
llama_memory_seq_keep(kv, seq_id);
}
// deprecated
void llama_kv_cache_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
}
void llama_kv_self_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
auto * kv = ctx->get_kv_self();
auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
kv->seq_add(seq_id, p0, p1, delta);
llama_memory_seq_add(kv, seq_id, p0, p1, delta);
}
// deprecated
void llama_kv_cache_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
}
void llama_kv_self_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
auto * kv = ctx->get_kv_self();
auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
kv->seq_div(seq_id, p0, p1, d);
llama_memory_seq_div(kv, seq_id, p0, p1, d);
}
// deprecated
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_self_seq_pos_max(ctx, seq_id);
}
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self();
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
auto * kv = llama_get_memory(ctx);
if (!kv) {
return 0;
return -1;
}
return kv->seq_pos_max(seq_id);
return llama_memory_seq_pos_min(kv, seq_id);
}
// deprecated
void llama_kv_cache_defrag(llama_context * ctx) {
llama_kv_self_defrag(ctx);
}
void llama_kv_self_defrag(llama_context * ctx) {
auto * kv = ctx->get_kv_self();
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
return -1;
}
// force defrag
kv->defrag_sched(-1.0f);
return llama_memory_seq_pos_max(kv, seq_id);
}
// deprecated
bool llama_kv_cache_can_shift(const llama_context * ctx) {
return llama_kv_self_can_shift(ctx);
void llama_kv_self_defrag(llama_context * ctx) {
// force defrag
ctx->kv_self_defrag_sched();
}
// deprecated
bool llama_kv_self_can_shift(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self();
auto * kv = llama_get_memory(ctx);
if (!kv) {
return false;
}
return kv->get_can_shift();
}
// deprecated
void llama_kv_cache_update(llama_context * ctx) {
llama_kv_self_update(ctx);
return llama_memory_can_shift(kv);
}
// llama state API
@@ -2642,7 +2855,7 @@ int32_t llama_decode(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->decode(batch);
if (ret != 0) {
if (ret != 0 && ret != 1) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
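// NOTE: caller-side sketch, not part of the upstream change, of how the return codes
// produced by decode() above can be handled: 0 = success, 1 = no memory slot found for
// the batch (recoverable, e.g. retry with a smaller batch or free sequences), 2 = aborted
// by the abort callback, negative = fatal error. The helper name is hypothetical;
// disabled with #if 0.
#if 0
static bool decode_with_retry_hint(llama_context * ctx, llama_batch batch, bool & retry_smaller) {
    retry_smaller = false;
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true;          // decoded successfully
    }
    if (ret == 1) {
        retry_smaller = true; // no memory slot: the caller may shrink the batch and try again
    }
    return false;             // ret == 2 (aborted) or ret < 0 (fatal)
}
#endif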
@@ -2676,6 +2889,7 @@ void llama_perf_context_print(const llama_context * ctx) {
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
}
void llama_perf_context_reset(llama_context * ctx) {