"tests/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "55fadb4c4ed9546bd9c27bb1241f83cf62cec1ab"
Commit 65973ceb authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

runner.go: Make KV entry accounting more robust

The structure of the accounting for KV cache shifting was carried
over from the old runner but it now doesn't feel natural with the new
runner. There are a number of invariants that should hold true but
are difficult to reason about. There is at least one bug report
that would imply that the invariants are not holding.

This reduces the number of implicit assumptions and is more forgiving
of unexpected situations. It also improves behavior around which input
tokens are kept when truncation occurs.

Bug #7545
parent bebef1e5
...@@ -2,6 +2,7 @@ package main ...@@ -2,6 +2,7 @@ package main
import ( import (
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"reflect" "reflect"
"time" "time"
...@@ -22,7 +23,11 @@ type InputCache struct { ...@@ -22,7 +23,11 @@ type InputCache struct {
lc *llama.Context lc *llama.Context
} }
func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) *InputCache { func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) (*InputCache, error) {
if kvSize/numSlots < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
slots := make([]InputCacheSlot, numSlots) slots := make([]InputCacheSlot, numSlots)
for i := range slots { for i := range slots {
...@@ -37,7 +42,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b ...@@ -37,7 +42,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
slots: slots, slots: slots,
multiUserCache: multiUserCache, multiUserCache: multiUserCache,
lc: lc, lc: lc,
} }, nil
} }
// Locking: Operations on InputCacheSlot (including finding one // Locking: Operations on InputCacheSlot (including finding one
...@@ -58,7 +63,7 @@ type InputCacheSlot struct { ...@@ -58,7 +63,7 @@ type InputCacheSlot struct {
lastUsed time.Time lastUsed time.Time
} }
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, int, error) { func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
var slot *InputCacheSlot var slot *InputCacheSlot
var numPast int var numPast int
var err error var err error
...@@ -75,7 +80,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach ...@@ -75,7 +80,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
slot, numPast, err = c.findBestCacheSlot(prompt) slot, numPast, err = c.findBestCacheSlot(prompt)
} }
if err != nil { if err != nil {
return nil, nil, 0, err return nil, nil, err
} }
if !cachePrompt { if !cachePrompt {
...@@ -102,7 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach ...@@ -102,7 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
prompt = prompt[numPast:] prompt = prompt[numPast:]
slot.Inputs = slot.Inputs[:numPast] slot.Inputs = slot.Inputs[:numPast]
return slot, prompt, numPast, nil return slot, prompt, nil
} }
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) { func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
...@@ -194,14 +199,30 @@ func countCommonPrefix(a []input, b []input) int { ...@@ -194,14 +199,30 @@ func countCommonPrefix(a []input, b []input) int {
return count return count
} }
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscard int, numPast int) { // Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
currentFree := c.numCtx - len(slot.Inputs)
discard := targetFree - currentFree
if discard <= 0 {
return
}
slog.Debug("context limit hit - shifting", "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models // TODO (jessegross): KV cache removal can fail for certain types of models
// server.cpp doesn't handle this, though we can be more graceful c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard)
c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+numDiscard) c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
c.lc.KvCacheSeqAdd(slot.Id, numKeep+numDiscard, numPast, -numDiscard)
for i := numKeep + numDiscard; i < len(slot.Inputs); i++ { for i := numKeep + discard; i < len(slot.Inputs); i++ {
slot.Inputs[i-numDiscard] = slot.Inputs[i] slot.Inputs[i-discard] = slot.Inputs[i]
} }
slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard] slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
} }
...@@ -34,9 +34,6 @@ type input struct { ...@@ -34,9 +34,6 @@ type input struct {
} }
type Sequence struct { type Sequence struct {
// number of inputs evaluated
numPast int
// batch index // batch index
iBatch int iBatch int
...@@ -112,21 +109,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen ...@@ -112,21 +109,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
params.numKeep = len(inputs) params.numKeep = len(inputs)
} }
if !params.embedding { if s.model.AddBOSToken() {
// Subtracting 4 ensures that at least 1 input can be discarded during shift params.numKeep += 1
params.numKeep = min(params.numKeep, s.cache.numCtx-4)
params.numKeep += s.bosToken
} else {
// Embeddings are 1 shot - just truncate to the context window, without ever shifting
params.numKeep = min(params.numKeep, s.cache.numCtx)
} }
// truncate to fit in context window // Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
if len(inputs) > s.cache.numCtx { if len(inputs) > s.cache.numCtx {
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep) slog.Warn("input exceeds context length", "prompt", len(inputs), "limit", s.cache.numCtx)
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
inputs = newInputs
} }
var sc *llama.SamplingContext var sc *llama.SamplingContext
...@@ -231,9 +222,6 @@ type Server struct { ...@@ -231,9 +222,6 @@ type Server struct {
// KV cache // KV cache
cache *InputCache cache *InputCache
// does this model require a beginning of sequence token?
bosToken int
// next sequence for prompt processing to avoid starvation // next sequence for prompt processing to avoid starvation
nextSeq int nextSeq int
...@@ -258,18 +246,6 @@ func (s *Server) allNil() bool { ...@@ -258,18 +246,6 @@ func (s *Server) allNil() bool {
return true return true
} }
func (s *Server) shiftContext(seq *Sequence) {
numLeft := seq.numPast - seq.numKeep
numDiscard := numLeft / 2
slog.Debug("context limit hit - shifting", "limit", s.cache.numCtx, "numPast", seq.numPast,
"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep, numDiscard, seq.numPast)
seq.numPast -= numDiscard
}
func flushPending(seq *Sequence) bool { func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "") joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{} seq.pendingResponses = []string{}
...@@ -374,12 +350,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) ...@@ -374,12 +350,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue continue
} }
if seq.numPast+len(seq.inputs) > s.cache.numCtx {
s.shiftContext(seq)
}
var numInputsProcessed int var numInputsProcessed int
shifted := false
for i, input := range seq.inputs { for i, input := range seq.inputs {
if len(seq.cache.Inputs)+1 > s.cache.numCtx {
if !shifted {
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
shifted = true
} else {
break
}
}
embedding := input.embed != nil embedding := input.embed != nil
// If we don't currently have a batch, use one of the correct type and // If we don't currently have a batch, use one of the correct type and
...@@ -403,13 +386,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) ...@@ -403,13 +386,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
} }
crossAttention = seq.crossAttention crossAttention = seq.crossAttention
batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id) batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
seq.numPast++ seq.cache.Inputs = append(seq.cache.Inputs, input)
numInputsProcessed++ numInputsProcessed++
} }
if numInputsProcessed > 0 { if numInputsProcessed > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.inputs[:numInputsProcessed]...)
seq.inputs = seq.inputs[numInputsProcessed:] seq.inputs = seq.inputs[numInputsProcessed:]
seq.iBatch = batch.NumTokens() - 1 seq.iBatch = batch.NumTokens() - 1
} }
...@@ -632,7 +614,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { ...@@ -632,7 +614,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Lock() s.mu.Lock()
for i, sq := range s.seqs { for i, sq := range s.seqs {
if sq == nil { if sq == nil {
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
...@@ -715,7 +697,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { ...@@ -715,7 +697,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
s.mu.Lock() s.mu.Lock()
for i, sq := range s.seqs { for i, sq := range s.seqs {
if sq == nil { if sq == nil {
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil { if err != nil {
s.mu.Unlock() s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
...@@ -802,10 +784,6 @@ func (s *Server) loadModel( ...@@ -802,10 +784,6 @@ func (s *Server) loadModel(
} }
} }
if s.model.AddBOSToken() {
s.bosToken = 1
}
if ppath != "" { if ppath != "" {
var err error var err error
s.image, err = NewImageContext(s.lc, ppath) s.image, err = NewImageContext(s.lc, ppath)
...@@ -814,7 +792,10 @@ func (s *Server) loadModel( ...@@ -814,7 +792,10 @@ func (s *Server) loadModel(
} }
} }
s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache) s.cache, err = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
if err != nil {
panic(err)
}
s.status = ServerStatusReady s.status = ServerStatusReady
s.ready.Done() s.ready.Done()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment