Unverified Commit 66b25392 authored by Bruce MacDonald's avatar Bruce MacDonald Committed by GitHub
Browse files

runner: clear cache when shift is not possible (#9433)

Clear KV cache when shift operation is not supported by model.
Added KvCacheCanShift() check to handle models that can't perform cache shifts,
falling back to full cache clear while preserving logical token history to
maintain expected behavior when context window fills up.
parent ef27d52e
......@@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
C.llama_kv_cache_defrag(c.c)
}
func (c *Context) KvCacheCanShift() bool {
return bool(C.llama_kv_cache_can_shift(c.c))
}
// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
......
......@@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
return discard
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
type ErrReprocessInputs struct {
Inputs []input
}
func (e *ErrReprocessInputs) Error() string {
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
}
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
......@@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
}
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
inputLen := len(slot.Inputs)
discard := c.ShiftDiscard(inputLen, numKeep)
if discard <= 0 {
return nil
......@@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
var shiftFailed bool
if c.lc.KvCacheCanShift() {
// For models that support shifting, attempt to shift the KV cache
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
shiftFailed = true
slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
} else {
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
}
} else {
// For models that don't support shifting
shiftFailed = true
slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
}
if shiftFailed {
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Clear the entire KV cache
_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
// Reset the slot inputs since we've cleared the cache
slot.Inputs = []input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}
}
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
for i := numKeep + discard; i < len(slot.Inputs); i++ {
// Standard shift succeeded - update input array
for i := numKeep + discard; i < inputLen; i++ {
slot.Inputs[i-discard] = slot.Inputs[i]
}
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
slot.Inputs = slot.Inputs[:inputLen-discard]
return nil
}
......@@ -389,7 +389,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Continue processing as normal
continue
} else {
return err
}
}
} else {
break
......
......@@ -239,6 +239,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
return discard
}
type ErrReprocessInputs struct {
Inputs []input.Input
}
func (e *ErrReprocessInputs) Error() string {
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
......@@ -258,11 +266,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
if c.cache != nil {
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
if err != nil {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
"id", slot.Id, "error", err)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the cache
_ = c.cache.Remove(slot.Id, 0, -1)
slot.Inputs = []input.Input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}
}
}
......
package ollamarunner
import (
"errors"
"fmt"
"image"
"testing"
"time"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
......@@ -425,3 +428,91 @@ func TestLoadCacheSlot(t *testing.T) {
})
}
}
// Mock implementation of the Cache interface
type mockCache struct {
shouldFail bool
}
// Implement only the methods needed for the test
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
if m.shouldFail {
return fmt.Errorf("mock cache removal error")
}
return nil
}
// Stub implementations for other interface methods
func (m *mockCache) SetLayer(layer int) {}
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
func (m *mockCache) Close() {}
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
func (m *mockCache) SetConfig(ml.CacheConfig) {}
func TestShiftCacheSlot(t *testing.T) {
tests := []struct {
name string
numCtx int32
inputs []input.Input
numKeep int32
cacheErr bool
wantErr any
wantInputsLen int
}{
{
name: "Normal shift",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: false, // No error
wantErr: nil,
wantInputsLen: 6, // After discarding 4 tokens
},
{
name: "Cache removal fails",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: true,
wantErr: &ErrReprocessInputs{},
wantInputsLen: 0, // Original inputs should be cleared
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mock := &mockCache{shouldFail: tt.cacheErr}
c := InputCache{
numCtx: tt.numCtx,
cache: mock,
}
slot := &InputCacheSlot{
Id: 123,
Inputs: make([]input.Input, len(tt.inputs)),
}
copy(slot.Inputs, tt.inputs)
err := c.ShiftCacheSlot(slot, tt.numKeep)
if tt.wantErr != nil {
if err == nil {
t.Errorf("Expected error but got nil")
return
}
if !errors.As(err, &tt.wantErr) {
t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
}
} else if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(slot.Inputs) != tt.wantInputsLen {
t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
}
})
}
}
......@@ -407,7 +407,15 @@ func (s *Server) processBatch() error {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest
continue
} else {
return err
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment