Commit c826e574 authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

runner.go: Better abstract vision model integration



-Update mllama to take the cross attention state as embeddings in
a batch, more similar to how Llava handles it. This improves
integration with the input cache.
-Pass locations in a prompt for embeddings using tags similar to Llava.
-Abstract interface to vision models so the main runner accesses Clip
and Mllama similarly
Co-authored-by: default avatarMichael Yang <mxyng@pm.me>
parent 712e99d4
...@@ -2699,7 +2699,7 @@ struct llama_hparams { ...@@ -2699,7 +2699,7 @@ struct llama_hparams {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
   
bool cross_attention_layer(uint32_t il) const { bool cross_attention_layers(uint32_t il) const {
return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
} }
}; };
...@@ -2731,6 +2731,9 @@ struct llama_cparams { ...@@ -2731,6 +2731,9 @@ struct llama_cparams {
bool offload_kqv; bool offload_kqv;
bool flash_attn; bool flash_attn;
bool no_perf; bool no_perf;
// TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches.
bool cross_attn = false;
   
enum llama_pooling_type pooling_type; enum llama_pooling_type pooling_type;
   
...@@ -3542,10 +3545,6 @@ struct llama_context { ...@@ -3542,10 +3545,6 @@ struct llama_context {
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
   
// TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches.
float * cross_attn_state = nullptr;
bool cross_attn_state_first_pass = true;
struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061] struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
}; };
   
...@@ -3782,7 +3781,7 @@ static bool llama_kv_cache_init( ...@@ -3782,7 +3781,7 @@ static bool llama_kv_cache_init(
   
for (int i = 0; i < (int) n_layer; i++) { for (int i = 0; i < (int) n_layer; i++) {
// for cross attention layers // for cross attention layers
if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) { if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
...@@ -7389,7 +7388,7 @@ static bool llm_load_tensors( ...@@ -7389,7 +7388,7 @@ static bool llm_load_tensors(
   
auto & layer = model.layers[i]; auto & layer = model.layers[i];
   
if (hparams.cross_attention_layer(i)) { if (hparams.cross_attention_layers(i)) {
layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}); layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128});
layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}); layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024});
layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}); layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd});
...@@ -9346,7 +9345,7 @@ static struct ggml_tensor * llm_build_inp_embd( ...@@ -9346,7 +9345,7 @@ static struct ggml_tensor * llm_build_inp_embd(
   
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
} else { } else {
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
inpL = lctx.inp_embd; inpL = lctx.inp_embd;
ggml_set_input(lctx.inp_embd); ggml_set_input(lctx.inp_embd);
} }
...@@ -9368,11 +9367,10 @@ static struct ggml_tensor * llm_build_inp_cross_attn_state( ...@@ -9368,11 +9367,10 @@ static struct ggml_tensor * llm_build_inp_cross_attn_state(
const llm_build_cb & cb) { const llm_build_cb & cb) {
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
   
struct ggml_tensor * inpCAS; struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4); cb(inpCAS, "inp_cross_attn_state", -1);
cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1); ggml_set_input(inpCAS);
ggml_set_input(lctx.inp_cross_attn_state); lctx.inp_cross_attn_state = inpCAS;
inpCAS = lctx.inp_cross_attn_state;
   
return inpCAS; return inpCAS;
} }
...@@ -10979,8 +10977,8 @@ struct llm_build_context { ...@@ -10979,8 +10977,8 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il); LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il); cb(cur, "attn_norm", il);
   
if (hparams.cross_attention_layer(il)) { if (hparams.cross_attention_layers(il)) {
if (!lctx.cross_attn_state) { if (!batch.embd && !cparams.cross_attn) {
continue; continue;
} }
   
...@@ -10991,42 +10989,28 @@ struct llm_build_context { ...@@ -10991,42 +10989,28 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
   
Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
cb(Qcur, "Qcur", il);
// TODO: is this required?
Qcur = ggml_cont(ctx0, Qcur);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
   
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il); Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
   
struct ggml_tensor * Kcur; struct ggml_tensor * Kcur, * Vcur;
if (lctx.cross_attn_state_first_pass) { if (batch.embd) {
Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS); Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
   
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
   
Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
cb(Kcur, "Kcur", il);
// TODO: is this required?
Kcur = ggml_cont(ctx0, Kcur);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
   
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il); Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
   
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
} else {
Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
cb(Kcur, "Kcur (view)", il);
}
   
struct ggml_tensor * Vcur;
if (lctx.cross_attn_state_first_pass) {
Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS); Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
   
...@@ -11038,6 +11022,9 @@ struct llm_build_context { ...@@ -11038,6 +11022,9 @@ struct llm_build_context {
   
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
} else { } else {
Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
cb(Kcur, "Kcur (view)", il);
Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]); Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
cb(Vcur, "Vcur (view)", il); cb(Vcur, "Vcur (view)", il);
} }
...@@ -11045,11 +11032,8 @@ struct llm_build_context { ...@@ -11045,11 +11032,8 @@ struct llm_build_context {
struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur); struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
cb(kq, "kq", il); cb(kq, "kq", il);
   
kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
cb(kq, "kq_scaled", il);
// TODO: apply causal masks // TODO: apply causal masks
struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq); struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
cb(kq_soft_max, "kq_soft_max", il); cb(kq_soft_max, "kq_soft_max", il);
   
Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur)); Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
...@@ -11139,8 +11123,8 @@ struct llm_build_context { ...@@ -11139,8 +11123,8 @@ struct llm_build_context {
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
   
cur = llm_build_kv(ctx0, lctx, kv_self, gf, cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
   
   
if (il == n_layer - 1) { if (il == n_layer - 1) {
...@@ -17197,10 +17181,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { ...@@ -17197,10 +17181,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
} }
   
if (batch.embd) { if (batch.embd) {
const int64_t n_embd = hparams.n_embd; if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
const int64_t n_tokens = batch.n_tokens; ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
// zero out inp_embd since it's not used
float * inp_embd_data = (float *)lctx.inp_embd->data;
for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
inp_embd_data[i] = 0.0f;
}
} else {
const int64_t n_embd = hparams.n_embd;
const int64_t n_tokens = batch.n_tokens;
   
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
}
} }
   
if (batch.pos && lctx.inp_pos) { if (batch.pos && lctx.inp_pos) {
...@@ -17209,14 +17202,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { ...@@ -17209,14 +17202,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
} }
   
// TODO (jmorganca): this might copy a lot of data on every request of a
// single generation even though it doesn't change, so we should
// find a way to not set this more than one time per image
if (lctx.inp_cross_attn_state &&
lctx.inp_cross_attn_state->buffer) {
ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
}
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
...@@ -17789,7 +17774,7 @@ static int llama_decode_internal( ...@@ -17789,7 +17774,7 @@ static int llama_decode_internal(
n_outputs = 1; n_outputs = 1;
} }
   
lctx.sbatch.from_batch(batch_all, n_embd, lctx.sbatch.from_batch(batch_all, batch_all.n_embd,
/* simple_split */ !kv_self.recurrent, /* simple_split */ !kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all); /* logits_all */ n_outputs == n_tokens_all);
   
...@@ -17899,10 +17884,6 @@ static int llama_decode_internal( ...@@ -17899,10 +17884,6 @@ static int llama_decode_internal(
   
llama_set_inputs(lctx, ubatch); llama_set_inputs(lctx, ubatch);
   
// TODO: replace with something better to find out if its
// our first actual pass
lctx.cross_attn_state_first_pass = false;
llama_graph_compute(lctx, gf, n_threads, threadpool); llama_graph_compute(lctx, gf, n_threads, threadpool);
   
// update the kv ring buffer // update the kv ring buffer
...@@ -18086,7 +18067,7 @@ static int llama_encode_internal( ...@@ -18086,7 +18067,7 @@ static int llama_encode_internal(
   
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
   
lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
   
const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
   
...@@ -20194,11 +20175,6 @@ struct llama_context * llama_new_context_with_model( ...@@ -20194,11 +20175,6 @@ struct llama_context * llama_new_context_with_model(
return ctx; return ctx;
} }
   
void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
ctx->cross_attn_state_first_pass = true;
ctx->cross_attn_state = cross_attn_state;
}
void llama_free(struct llama_context * ctx) { void llama_free(struct llama_context * ctx) {
delete ctx; delete ctx;
} }
...@@ -21686,6 +21662,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { ...@@ -21686,6 +21662,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
ctx->cparams.causal_attn = causal_attn; ctx->cparams.causal_attn = causal_attn;
} }
   
void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
ctx->cparams.cross_attn = cross_attention;
}
struct llama_batch llama_batch_get_one( struct llama_batch llama_batch_get_one(
llama_token * tokens, llama_token * tokens,
int32_t n_tokens, int32_t n_tokens,
...@@ -21695,6 +21675,7 @@ struct llama_batch llama_batch_get_one( ...@@ -21695,6 +21675,7 @@ struct llama_batch llama_batch_get_one(
/*n_tokens =*/ n_tokens, /*n_tokens =*/ n_tokens,
/*tokens =*/ tokens, /*tokens =*/ tokens,
/*embd =*/ nullptr, /*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr, /*pos =*/ nullptr,
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
...@@ -21710,6 +21691,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ ...@@ -21710,6 +21691,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*n_tokens =*/ 0, /*n_tokens =*/ 0,
/*tokens =*/ nullptr, /*tokens =*/ nullptr,
/*embd =*/ nullptr, /*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr, /*pos =*/ nullptr,
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
...@@ -21721,6 +21703,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ ...@@ -21721,6 +21703,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
   
if (embd) { if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
batch.n_embd = embd;
} else { } else {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
} }
......
...@@ -111,6 +111,28 @@ func PrintSystemInfo() string { ...@@ -111,6 +111,28 @@ func PrintSystemInfo() string {
return C.GoString(C.llama_print_system_info()) + compiler return C.GoString(C.llama_print_system_info()) + compiler
} }
func GetModelArch(modelPath string) (string, error) {
mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp))
gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
if gguf_ctx == nil {
return "", errors.New("unable to load model file")
}
defer C.gguf_free(gguf_ctx)
key := C.CString("general.architecture")
defer C.free(unsafe.Pointer(key))
arch_index := C.gguf_find_key(gguf_ctx, key)
if int(arch_index) < 0 {
return "", errors.New("unknown model architecture")
}
arch := C.gguf_get_val_str(gguf_ctx, arch_index)
return C.GoString(arch), nil
}
type ContextParams struct { type ContextParams struct {
c C.struct_llama_context_params c C.struct_llama_context_params
} }
...@@ -443,71 +465,36 @@ func Quantize(infile, outfile string, ftype uint32) error { ...@@ -443,71 +465,36 @@ func Quantize(infile, outfile string, ftype uint32) error {
return nil return nil
} }
// llava // vision processing
type ClipContext struct { type ClipContext struct {
c *C.struct_clip_ctx c *C.struct_clip_ctx
m *C.struct_mllama_ctx
IsMllama bool
embedPin runtime.Pinner
pinned bool
}
func getVisionArch(mp *C.char) (string, error) {
gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
if gguf_ctx == nil {
return "", errors.New("unable to load vision projector")
}
defer C.gguf_free(gguf_ctx)
arch_index := C.gguf_find_key(gguf_ctx, C.CString("general.architecture"))
if int(arch_index) < 0 {
return "", errors.New("unknown vision model architecture")
}
arch := C.gguf_get_val_str(gguf_ctx, arch_index)
return C.GoString(arch), nil
} }
func NewClipContext(modelPath string) (*ClipContext, error) { func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
mp := C.CString(modelPath) mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp)) defer C.free(unsafe.Pointer(mp))
c := C.clip_model_load(mp, 1)
arch, err := getVisionArch(mp) projEmbedSize := int(C.clip_n_mmproj_embd(c))
if err != nil { modelEmbedSize := llamaContext.Model().NEmbd()
return nil, err if projEmbedSize != modelEmbedSize {
return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
} }
var cc ClipContext return &ClipContext{c: c}, nil
if arch == "clip" {
cc.c = C.clip_model_load(mp, 1)
} else if arch == "mllama" {
cc.m = C.mllama_model_load(mp, 1)
cc.IsMllama = true
} else {
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
}
// XXX: check embedding size?
return &cc, nil
} }
func (c *ClipContext) Free() { func (c *ClipContext) Free() {
if c.c != nil { C.clip_free(c.c)
C.clip_free(c.c)
}
if c.m != nil {
C.mllama_free(c.m)
}
} }
func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 { func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) [][]float32 {
c := C.llava_image_embed_make_with_bytes(clipContext.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data))) l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
numTokens := int(c.n_image_pos) numTokens := int(l.n_image_pos)
numEmbed := llamaContext.Model().NEmbd() numEmbed := llamaContext.Model().NEmbd()
s := unsafe.Slice((*float32)(c.embed), numEmbed*numTokens) s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)
embed := make([][]float32, numTokens) embed := make([][]float32, numTokens)
rows := make([]float32, len(s)) rows := make([]float32, len(s))
...@@ -517,51 +504,57 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data [] ...@@ -517,51 +504,57 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []
embed[i] = rows[i*numEmbed : (i+1)*numEmbed] embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
} }
C.llava_image_embed_free(c) C.llava_image_embed_free(l)
return embed return embed
} }
func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 { type MllamaContext struct {
c *C.struct_mllama_ctx
}
func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp))
c := C.mllama_model_load(mp, 1)
projEmbedSize := int(C.mllama_n_embd(c))
modelEmbedSize := llamaContext.Model().NEmbd()
if projEmbedSize != modelEmbedSize {
return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
}
return &MllamaContext{c: c}, nil
}
func (m *MllamaContext) Free() {
C.mllama_free(m.c)
}
func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) [][]float32 {
img := C.mllama_image_init() img := C.mllama_image_init()
defer C.mllama_image_free(img) defer C.mllama_image_free(img)
C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img) C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)
numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m)) rows := make([]float32, m.EmbedSize(llamaContext))
numEmbed := llamaContext.Model().NEmbd() C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
rows := make([]float32, numEmbed*numTokens)
C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
embed := make([][]float32, numTokens) embed := make([][]float32, 1)
for i := range embed { embed[0] = rows
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
}
return embed return embed
} }
// This really needs to be set on a batch instead func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) { numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
if embed != nil { numEmbed := llamaContext.Model().NEmbd()
if clipContext.pinned {
panic("Cross attention state already pinned")
}
embedData := &embed[0][0]
clipContext.embedPin.Pin(embedData)
clipContext.pinned = true
C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData))) return numTokens * numEmbed
} else { }
C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(C.NULL))
if clipContext.pinned { func (c *Context) SetCrossAttention(state bool) {
clipContext.embedPin.Unpin() C.llama_set_cross_attention(c.c, C.bool(state))
clipContext.pinned = false
}
}
} }
// sampling // sampling
......
...@@ -266,6 +266,7 @@ extern "C" { ...@@ -266,6 +266,7 @@ extern "C" {
llama_token * token; llama_token * token;
float * embd; float * embd;
int32_t n_embd;
llama_pos * pos; llama_pos * pos;
int32_t * n_seq_id; int32_t * n_seq_id;
llama_seq_id ** seq_id; llama_seq_id ** seq_id;
...@@ -451,7 +452,7 @@ extern "C" { ...@@ -451,7 +452,7 @@ extern "C" {
// TODO (jmorganca): this should most likely be passed in as part of a batch // TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches. // and not set on the context for all batches.
LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state); LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
// Frees all allocated memory // Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx); LLAMA_API void llama_free(struct llama_context * ctx);
......
...@@ -435,7 +435,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ ...@@ -435,7 +435,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
if (n_eval > n_batch) { if (n_eval > n_batch) {
n_eval = n_batch; n_eval = n_batch;
} }
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) { if (llama_decode(ctx_llama, batch)) {
LOG_ERR("%s : failed to eval\n", __func__); LOG_ERR("%s : failed to eval\n", __func__);
return false; return false;
......
This diff is collapsed.
...@@ -2,7 +2,6 @@ package main ...@@ -2,7 +2,6 @@ package main
import ( import (
"errors" "errors"
"hash/maphash"
"log/slog" "log/slog"
"reflect" "reflect"
"time" "time"
...@@ -20,10 +19,6 @@ type InputCache struct { ...@@ -20,10 +19,6 @@ type InputCache struct {
// optimize cache eviction for multiple users // optimize cache eviction for multiple users
multiUserCache bool multiUserCache bool
// cache of images to embeddings
images []imageCache
imageHash maphash.Hash
lc *llama.Context lc *llama.Context
} }
...@@ -41,7 +36,6 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b ...@@ -41,7 +36,6 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
numCtx: kvSize / numSlots, numCtx: kvSize / numSlots,
slots: slots, slots: slots,
multiUserCache: multiUserCache, multiUserCache: multiUserCache,
images: make([]imageCache, numSlots),
lc: lc, lc: lc,
} }
} }
...@@ -211,55 +205,3 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscar ...@@ -211,55 +205,3 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscar
} }
slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard] slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
} }
// Locking: Lookup and store operations on imageCache require a lock
// to be held that serializes these with each other. Hash does not
// require a lock nor they need to be serialized with InputCacheSlot.
type imageCache struct {
key uint64
val [][]float32
lastUsed time.Time
}
func (c *InputCache) HashImage(image []byte) uint64 {
c.imageHash.Reset()
_, _ = c.imageHash.Write(image)
return c.imageHash.Sum64()
}
var ErrImageNotFound = errors.New("image not found in cache")
func (c *InputCache) FindImage(hash uint64) ([][]float32, error) {
for i := range c.images {
if c.images[i].key == hash {
slog.Debug("loading image embeddings from cache", "entry", i)
c.images[i].lastUsed = time.Now()
return c.images[i].val, nil
}
}
return nil, ErrImageNotFound
}
func (c *InputCache) AddImage(hash uint64, embed [][]float32) {
best := time.Now()
var bestImage int
for i := range c.images {
if c.images[i].key == hash {
bestImage = i
break
}
if c.images[i].lastUsed.Compare(best) < 0 {
best = c.images[i].lastUsed
bestImage = i
}
}
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
c.images[bestImage].key = hash
c.images[bestImage].val = embed
c.images[bestImage].lastUsed = time.Now()
}
package main package main
import ( import (
"reflect"
"testing" "testing"
"time" "time"
) )
...@@ -228,77 +227,3 @@ func TestFindCacheSlot(t *testing.T) { ...@@ -228,77 +227,3 @@ func TestFindCacheSlot(t *testing.T) {
}) })
} }
} }
func TestImageCache(t *testing.T) {
cache := NewInputCache(nil, 2048, 4, false)
valA := [][]float32{{0.1, 0.2}, {0.3}}
valB := [][]float32{{0.4}, {0.5}, {0.6}}
valC := [][]float32{{0.7}}
valD := [][]float32{{0.8}}
valE := [][]float32{{0.9}}
// Empty cache
result, err := cache.FindImage(0x5adb61d31933a946)
if err != ErrImageNotFound {
t.Errorf("found result in empty cache: result %v, err %v", result, err)
}
// Insert A
cache.AddImage(0x5adb61d31933a946, valA)
result, err = cache.FindImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Insert B
cache.AddImage(0x011551369a34a901, valB)
result, err = cache.FindImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valB) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Replace B with C
cache.AddImage(0x011551369a34a901, valC)
result, err = cache.FindImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valC) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Evict A
cache.AddImage(0x756b218a517e7353, valB)
cache.AddImage(0x75e5e8d35d7e3967, valD)
cache.AddImage(0xd96f7f268ca0646e, valE)
result, err = cache.FindImage(0x5adb61d31933a946)
if reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x756b218a517e7353)
if !reflect.DeepEqual(result, valB) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valC) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0x75e5e8d35d7e3967)
if !reflect.DeepEqual(result, valD) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.FindImage(0xd96f7f268ca0646e)
if !reflect.DeepEqual(result, valE) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
}
package main
import (
"errors"
"fmt"
"hash/maphash"
"log/slog"
"sync"
"time"
"github.com/ollama/ollama/llama"
)
const imageCacheSize = 4
type ImageContext struct {
// mu is required to be held when generating embeddings or accessing the cache
mu sync.Mutex
clip *llama.ClipContext
mllama *llama.MllamaContext
// cache of images to embeddings
images []imageCache
imageHash maphash.Hash
}
func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
arch, err := llama.GetModelArch(modelPath)
if err != nil {
return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
}
var c ImageContext
if arch == "clip" {
c.clip, err = llama.NewClipContext(llamaContext, modelPath)
} else if arch == "mllama" {
c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
} else {
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
}
if err != nil {
return nil, err
}
c.images = make([]imageCache, imageCacheSize)
return &c, nil
}
func (c *ImageContext) Free(modelPath string) {
if c == nil {
return
}
if c.clip != nil {
c.clip.Free()
}
if c.mllama != nil {
c.mllama.Free()
}
}
func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) [][]float32 {
if c == nil {
return nil
}
hash := c.hashImage(data)
c.mu.Lock()
defer c.mu.Unlock()
embed, err := c.findImage(hash)
if err != nil {
if c.mllama != nil {
embed = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
} else if c.clip != nil {
embed = c.clip.NewEmbed(llamaContext, data)
} else {
return nil
}
c.addImage(hash, embed)
}
return embed
}
func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
if c != nil && c.mllama != nil {
return c.mllama.EmbedSize(llamaContext)
} else {
return llamaContext.Model().NEmbd()
}
}
type imageCache struct {
key uint64
val [][]float32
lastUsed time.Time
}
func (c *ImageContext) hashImage(image []byte) uint64 {
c.imageHash.Reset()
_, _ = c.imageHash.Write(image)
return c.imageHash.Sum64()
}
var errImageNotFound = errors.New("image not found in cache")
func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
for i := range c.images {
if c.images[i].key == hash {
slog.Debug("loading image embeddings from cache", "entry", i)
c.images[i].lastUsed = time.Now()
return c.images[i].val, nil
}
}
return nil, errImageNotFound
}
func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
best := time.Now()
var bestImage int
for i := range c.images {
if c.images[i].key == hash {
bestImage = i
break
}
if c.images[i].lastUsed.Compare(best) < 0 {
best = c.images[i].lastUsed
bestImage = i
}
}
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
c.images[bestImage].key = hash
c.images[bestImage].val = embed
c.images[bestImage].lastUsed = time.Now()
}
package main
import (
"reflect"
"testing"
)
func TestImageCache(t *testing.T) {
cache := ImageContext{images: make([]imageCache, 4)}
valA := [][]float32{{0.1, 0.2}, {0.3}}
valB := [][]float32{{0.4}, {0.5}, {0.6}}
valC := [][]float32{{0.7}}
valD := [][]float32{{0.8}}
valE := [][]float32{{0.9}}
// Empty cache
result, err := cache.findImage(0x5adb61d31933a946)
if err != errImageNotFound {
t.Errorf("found result in empty cache: result %v, err %v", result, err)
}
// Insert A
cache.addImage(0x5adb61d31933a946, valA)
result, err = cache.findImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Insert B
cache.addImage(0x011551369a34a901, valB)
result, err = cache.findImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valB) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Replace B with C
cache.addImage(0x011551369a34a901, valC)
result, err = cache.findImage(0x5adb61d31933a946)
if !reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valC) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
// Evict A
cache.addImage(0x756b218a517e7353, valB)
cache.addImage(0x75e5e8d35d7e3967, valD)
cache.addImage(0xd96f7f268ca0646e, valE)
result, err = cache.findImage(0x5adb61d31933a946)
if reflect.DeepEqual(result, valA) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x756b218a517e7353)
if !reflect.DeepEqual(result, valB) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x011551369a34a901)
if !reflect.DeepEqual(result, valC) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0x75e5e8d35d7e3967)
if !reflect.DeepEqual(result, valD) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
result, err = cache.findImage(0xd96f7f268ca0646e)
if !reflect.DeepEqual(result, valE) {
t.Errorf("failed to find expected value: result %v, err %v", result, err)
}
}
...@@ -190,57 +190,22 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { ...@@ -190,57 +190,22 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
return nil, fmt.Errorf("invalid image index: %d", n) return nil, fmt.Errorf("invalid image index: %d", n)
} }
hash := s.cache.HashImage(images[imageIndex].Data) embed := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
// Vision models cannot be accessed concurrently
s.clip.mu.Lock()
embed, err := s.cache.FindImage(hash)
if err != nil {
embed = llama.NewLlavaImageEmbed(s.lc, s.clip.cc, images[imageIndex].Data)
s.cache.AddImage(hash, embed)
}
s.clip.mu.Unlock()
for _, e := range embed { for _, e := range embed {
inputs = append(inputs, input{embed: e}) inputs = append(inputs, input{embed: e})
} }
} }
} }
if s.clip.cc != nil {
var embed [][]float32
if s.clip.cc.IsMllama && len(images) >= 1 {
hash := s.cache.HashImage(images[0].Data)
s.clip.mu.Lock()
var err error
embed, err = s.cache.FindImage(hash)
if err != nil {
embed = llama.NewMllamaImageEmbed(s.lc, s.clip.cc, images[0].Data, images[0].AspectRatioID)
s.cache.AddImage(hash, embed)
}
s.clip.mu.Unlock()
}
s.mu.Lock()
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, embed)
s.mu.Unlock()
}
return inputs, nil return inputs, nil
} }
type clip struct {
cc *llama.ClipContext
mu sync.Mutex
}
type Server struct { type Server struct {
model *llama.Model model *llama.Model
lc *llama.Context lc *llama.Context
// required for image embeddings // required for image embeddings
clip clip image *ImageContext
batchSize int batchSize int
...@@ -322,14 +287,12 @@ func flushPending(seq *Sequence) bool { ...@@ -322,14 +287,12 @@ func flushPending(seq *Sequence) bool {
func (s *Server) removeSequence(seqIndex int, reason string) { func (s *Server) removeSequence(seqIndex int, reason string) {
seq := s.seqs[seqIndex] seq := s.seqs[seqIndex]
s.lc.SetCrossAttention(false)
flushPending(seq) flushPending(seq)
seq.doneReason = reason seq.doneReason = reason
close(seq.responses) close(seq.responses)
close(seq.embedding) close(seq.embedding)
seq.cache.InUse = false seq.cache.InUse = false
if s.clip.cc != nil {
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, nil)
}
s.seqs[seqIndex] = nil s.seqs[seqIndex] = nil
} }
...@@ -341,7 +304,7 @@ func (s *Server) run(ctx context.Context) { ...@@ -341,7 +304,7 @@ func (s *Server) run(ctx context.Context) {
tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs)) tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
defer tokenBatch.Free() defer tokenBatch.Free()
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.lc.Model().NEmbd(), len(s.seqs)) embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
defer embedBatch.Free() defer embedBatch.Free()
for { for {
...@@ -642,12 +605,20 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { ...@@ -642,12 +605,20 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Lock() s.mu.Lock()
for i, sq := range s.seqs { for i, sq := range s.seqs {
if sq == nil { if sq == nil {
for _, input := range seq.inputs {
if input.embed != nil {
s.lc.SetCrossAttention(true)
break
}
}
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return return
} }
s.seqs[i] = seq s.seqs[i] = seq
s.cond.Signal() s.cond.Signal()
break break
...@@ -815,7 +786,7 @@ func (s *Server) loadModel( ...@@ -815,7 +786,7 @@ func (s *Server) loadModel(
if ppath != "" { if ppath != "" {
var err error var err error
s.clip.cc, err = llama.NewClipContext(ppath) s.image, err = NewImageContext(s.lc, ppath)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
...@@ -75,11 +75,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. ...@@ -75,11 +75,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
currMsgIdx := n currMsgIdx := n
if isMllama { for cnt, msg := range msgs[currMsgIdx:] {
lastMsgIdx := len(msgs) - 1 prefix := ""
for i := lastMsgIdx; i >= currMsgIdx; i-- { imgPrompt := ""
if len(msgs[i].Images) > 0 { prompt := msg.Content
data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
for _, i := range msg.Images {
var imgData llm.ImageData
if isMllama {
data, aspectRatioID, err := imageproc.Preprocess(i)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
...@@ -90,37 +95,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. ...@@ -90,37 +95,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
return "", nil, err return "", nil, err
} }
imgData := llm.ImageData{ imgData = llm.ImageData{
ID: len(images),
Data: buf.Bytes(), Data: buf.Bytes(),
AspectRatioID: aspectRatioID, AspectRatioID: aspectRatioID,
} }
imgPrompt = "<|image|>"
msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content) } else {
images = append(images, imgData) imgData = llm.ImageData{
break
}
}
} else {
for cnt, msg := range msgs[currMsgIdx:] {
prefix := ""
prompt := msg.Content
for _, i := range msg.Images {
imgData := llm.ImageData{
ID: len(images), ID: len(images),
Data: i, Data: i,
} }
imgPrompt = " "
}
imgTag := fmt.Sprintf("[img-%d]", imgData.ID) imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") { if !strings.Contains(prompt, "[img]") {
prefix += imgTag prefix += imgTag
} else { } else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1) prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
images = append(images, imgData)
} }
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt)
images = append(images, imgData)
} }
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + imgPrompt + prompt)
} }
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window
......
...@@ -249,7 +249,7 @@ func TestChatPrompt(t *testing.T) { ...@@ -249,7 +249,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}}, {Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
}, },
expect: expect{ expect: expect{
prompt: "<|image|>How many hotdogs are in this image? ", prompt: "[img-0]<|image|>How many hotdogs are in this image? ",
images: [][]byte{imgBuf}, images: [][]byte{imgBuf},
aspectRatioID: 1, aspectRatioID: 1,
}, },
...@@ -264,7 +264,7 @@ func TestChatPrompt(t *testing.T) { ...@@ -264,7 +264,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
}, },
expect: expect{ expect: expect{
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ", prompt: "You're a test, Harry! I-I'm a what? [img-0]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf}, images: [][]byte{imgBuf},
aspectRatioID: 1, aspectRatioID: 1,
}, },
...@@ -279,8 +279,8 @@ func TestChatPrompt(t *testing.T) { ...@@ -279,8 +279,8 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
}, },
expect: expect{ expect: expect{
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ", prompt: "[img-0]<|image|>You're a test, Harry! I-I'm a what? [img-1]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf2}, images: [][]byte{imgBuf, imgBuf2},
aspectRatioID: 1, aspectRatioID: 1,
}, },
}, },
...@@ -294,7 +294,7 @@ func TestChatPrompt(t *testing.T) { ...@@ -294,7 +294,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "Which ones have mustard?"}, {Role: "user", Content: "Which ones have mustard?"},
}, },
expect: expect{ expect: expect{
prompt: "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ", prompt: "[img-0]<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
images: [][]byte{imgBuf}, images: [][]byte{imgBuf},
aspectRatioID: 1, aspectRatioID: 1,
}, },
......
...@@ -205,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -205,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
images[i] = llm.ImageData{Data: buf.Bytes(), AspectRatioID: aspectRatioID} images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: aspectRatioID}
} else { } else {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]} images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
} }
...@@ -239,11 +239,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -239,11 +239,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
for _, i := range images { for _, i := range images {
imgPrompt := ""
if isMllama { if isMllama {
msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"}) imgPrompt = "<|image|>"
} else {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
} }
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
} }
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment