cache.go 7.87 KB
Newer Older
Jesse Gross's avatar
Jesse Gross committed
1
package ollamarunner
2
3
4

import (
	"errors"
5
	"fmt"
6
	"log/slog"
Jesse Gross's avatar
Jesse Gross committed
7
	"math"
8
9
	"time"

Jesse Gross's avatar
Jesse Gross committed
10
11
12
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
13
	"github.com/ollama/ollama/model/input"
14
15
16
17
)

type InputCache struct {
	// context window size (per slot)
Jesse Gross's avatar
Jesse Gross committed
18
19
20
21
22
23
	numCtx int32

	// does the cache store data or do we need to always send the full input?
	// note that when enabled is false the underlying cache may either be nil
	// or a non-nil dummy that doesn't actually store anything
	enabled bool
24
25
26
27
28
29
30

	// individual KV caches
	slots []InputCacheSlot

	// optimize cache eviction for multiple users
	multiUserCache bool

Jesse Gross's avatar
Jesse Gross committed
31
	cache kvcache.Cache
32
33
}

34
35
36
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
	numCtx := kvSize / int32(numSlots)

37
38
	if int(numCtx) < batchSize {
		return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots)
39
40
	}

41
42
43
	slots := make([]InputCacheSlot, numSlots)

	for i := range slots {
44
		slots[i] = InputCacheSlot{Id: i}
45
46
	}

Jesse Gross's avatar
Jesse Gross committed
47
48
	cache := model.Config().Cache
	if cache != nil {
49
		cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
Jesse Gross's avatar
Jesse Gross committed
50
51
	}

52
	return &InputCache{
53
		numCtx:         numCtx,
Jesse Gross's avatar
Jesse Gross committed
54
		enabled:        cache != nil,
55
56
		slots:          slots,
		multiUserCache: multiUserCache,
Jesse Gross's avatar
Jesse Gross committed
57
		cache:          cache,
58
	}, nil
59
60
}

Jesse Gross's avatar
Jesse Gross committed
61
62
63
func kvCacheTypeFromStr(s string) ml.DType {
	switch s {
	case "q8_0":
64
		return ml.DTypeQ80
Jesse Gross's avatar
Jesse Gross committed
65
	case "q4_0":
66
		return ml.DTypeQ40
Jesse Gross's avatar
Jesse Gross committed
67
68
69
70
71
72
	default:
		return ml.DTypeF16
	}
}

func (c *InputCache) Close() {
73
74
	if c != nil && c.cache != nil {
		c.cache.Close()
Jesse Gross's avatar
Jesse Gross committed
75
	}
Jesse Gross's avatar
Jesse Gross committed
76
77
}

78
// Locking: Operations on InputCacheSlot (including finding one
79
// through LoadCacheSlot) require a lock to be held that serializes
Jesse Gross's avatar
Jesse Gross committed
80
// these operations with each other and processBatch
81
82
83
84
85
86

type InputCacheSlot struct {
	// Index in the KV cache
	Id int

	// Inputs that are stored in the KV cache
87
	Inputs []*input.Input
88
89
90
91
92
93
94
95

	// is this cache actively being processed as part of a sequence?
	InUse bool

	// last time this cache was used (as of start of processing)
	lastUsed time.Time
}

Michael Yang's avatar
Michael Yang committed
96
// LoadCacheSlot claims a cache slot for prompt and returns it together with
// the suffix of prompt that still has to be processed.
//
// Leading inputs that already match the slot's cached contents are reused
// unless cachePrompt is false. Anything cached beyond the reused prefix is
// removed from the KV cache. At least one input is always left in the returned
// suffix so the caller has something to sample a response from.
//
// On success the slot is marked InUse. On error no slot is claimed.
func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
	// An empty prompt leaves nothing to sample, and it would drive numPast to
	// -1 below, which panics on prompt[:numPast].
	if len(prompt) == 0 {
		return nil, nil, errors.New("prompt must contain at least one input")
	}

	var (
		slot    *InputCacheSlot
		numPast int32
		err     error
	)

	// In single-user scenarios, the longest cache slot works fine for getting good input
	// cache hit rates and it keeps the footprint of the cache small, which improves throughput.
	// For multiple users, the "best" cache slot produces better input cache hit rates
	// at the cost of worse performance when we miss the input cache.
	if c.multiUserCache {
		slot, numPast, err = c.findBestCacheSlot(prompt)
	} else {
		slot, numPast, err = c.findLongestCacheSlot(prompt)
	}
	if err != nil {
		return nil, nil, err
	}

	if !cachePrompt {
		numPast = 0
	}

	if numPast == int32(len(prompt)) {
		// Leave one input to sample so we can get a response
		numPast--
	}

	if c.cache != nil {
		if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
			numPast = 0
		}

		err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
		if err != nil {
			// Some models don't support partial erasure
			err = c.cache.Remove(slot.Id, 0, math.MaxInt32)
			if err != nil {
				// The slot's KV contents are now unknown, so forget the
				// recorded prefix rather than risk reusing stale entries.
				slot.Inputs = nil
				return nil, nil, err
			}
			numPast = 0
		}
	}

	// Claim the slot only once it is known to be usable, so a failure above
	// cannot leave it marked InUse with no owner to ever release it.
	slot.InUse = true
	slot.lastUsed = time.Now()

	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
		"used", numPast, "remaining", int32(len(prompt))-numPast)

	// Cap the capacity so later appends to slot.Inputs reallocate instead of
	// writing into the backing array shared with the returned prompt suffix.
	slot.Inputs = prompt[:numPast:numPast]

	return slot, prompt[numPast:], nil
}

151
// findLongestCacheSlot returns the free slot whose cached inputs share the
// longest prefix with prompt, along with the length of that prefix. Ties go
// to the lowest-indexed slot. An error is returned when every slot is in use.
func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
	var best *InputCacheSlot
	bestLen := int32(-1)

	for i := range c.slots {
		candidate := &c.slots[i]
		if candidate.InUse {
			continue
		}
		if n := countCommonPrefix(candidate.Inputs, prompt); n > bestLen {
			best, bestLen = candidate, n
		}
	}

	if best == nil {
		return nil, 0, errors.New("no available cache slots")
	}
	return best, bestLen, nil
}

174
// findBestCacheSlot selects a slot for prompt using a strategy tuned for
// multiple concurrent users.
//
// It first finds the slot, in use or not, sharing the longest prefix with
// prompt. If that slot is free and its entire cached contents are a prefix of
// prompt, it is reused directly. Otherwise the least recently used free slot is
// chosen, evicting whatever it held. When the longest match lives in a
// different slot, that shared prefix is copied ("forked") into the chosen slot
// so it need not be recomputed.
//
// It returns the chosen slot and the number of leading prompt inputs already
// present in it. It returns an error if no slot is free.
func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
	var oldestSlot *InputCacheSlot

	longest := int32(-1)
	var longestSlot *InputCacheSlot

	for i := range c.slots {
		s := &c.slots[i]

		if count := countCommonPrefix(s.Inputs, prompt); count > longest {
			longest = count
			longestSlot = s
		}

		// Track the least recently used free slot. Seeding from the first
		// free slot (rather than a time.Now() sentinel) guarantees that a
		// free slot is always found when one exists.
		if !s.InUse && (oldestSlot == nil || s.lastUsed.Before(oldestSlot.lastUsed)) {
			oldestSlot = s
		}
	}

	// No slots at all.
	if longestSlot == nil {
		return nil, 0, errors.New("no available cache slots")
	}

	if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse {
		return longestSlot, longest, nil
	}

	// oldestSlot only ever points at a free slot, so nil means every slot is
	// busy. The previous check of oldestSlot.InUse dereferenced nil here.
	if oldestSlot == nil {
		return nil, 0, errors.New("no available cache slots")
	}

	if len(oldestSlot.Inputs) != 0 {
		slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
			"used", oldestSlot.lastUsed)
	}

	if longest > 0 && longestSlot != oldestSlot {
		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
			len(longestSlot.Inputs))
		oldestSlot.Inputs = make([]*input.Input, longest)
		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
		if c.cache != nil {
			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
		}
	}

	return oldestSlot, longest, nil
}

220
func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
Jesse Gross's avatar
Jesse Gross committed
221
	var count int32
222
223
224
225
226
227

	for i := range a {
		if i >= len(b) {
			break
		}

228
		if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
229
230
231
232
233
234
235
236
237
			break
		}

		count++
	}

	return count
}

238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
// ShiftDiscard computes how many inputs can be discarded from the cache. Inputs in the same batch
// are discarded together.
//
// The first numKeep inputs are never discarded. Discarding starts right after them and stops
// once the number of free entries in the context reaches max((numCtx-numKeep)/2, 1). It is
// extended as needed so that an input with SameBatch > 0 is discarded together with the
// SameBatch inputs that follow it. A numKeep outside [0, len(inputs)] is clamped to that range.
func (c *InputCache) ShiftDiscard(inputs []*input.Input, numKeep int32) int32 {
	// Clamp numKeep so inputs[numKeep:] cannot go out of range (it used to
	// panic). With nothing beyond the kept prefix there is nothing to discard.
	numKeep = max(numKeep, 0)
	if numKeep >= int32(len(inputs)) {
		return 0
	}

	targetFree := max((c.numCtx-numKeep)/2, 1)
	currentFree := c.numCtx - int32(len(inputs))

	var discard, sameBatch int32
	// The loop variable is named inp so it does not shadow the input package.
	for _, inp := range inputs[numKeep:] {
		if sameBatch <= 0 && currentFree >= targetFree {
			break
		}

		sameBatch--
		currentFree++
		discard++

		if inp.SameBatch > 0 {
			sameBatch = int32(inp.SameBatch)
		}
	}

	return discard
}

262
type ErrReprocessInputs struct {
263
	Inputs []*input.Input
264
265
266
267
268
269
}

func (e *ErrReprocessInputs) Error() string {
	return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
}

270
271
272
273
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
Jesse Gross's avatar
Jesse Gross committed
274
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
275
276
277
278
	if numKeep >= c.numCtx {
		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
	}

Jesse Gross's avatar
Jesse Gross committed
279
	inputLen := int32(len(slot.Inputs))
280
	discard := c.ShiftDiscard(slot.Inputs, numKeep)
281
282

	if discard <= 0 {
283
		return nil
284
285
	}

286
	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
287
288
		"keep", numKeep, "discard", discard)

Jesse Gross's avatar
Jesse Gross committed
289
290
291
	if c.cache != nil {
		err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
		if err != nil {
292
293
294
295
			slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
				"id", slot.Id, "error", err)

			// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
296
			newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
297
298
299
300
			copy(newInputs[:numKeep], slot.Inputs[:numKeep])
			copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])

			// Reset the cache
301
			_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
302
			slot.Inputs = []*input.Input{}
303
304
305

			// Return error with inputs that need to be reprocessed
			return &ErrReprocessInputs{Inputs: newInputs}
Jesse Gross's avatar
Jesse Gross committed
306
		}
307
	}
308

Jesse Gross's avatar
Jesse Gross committed
309
	for i := numKeep + discard; i < inputLen; i++ {
310
		slot.Inputs[i-discard] = slot.Inputs[i]
311
	}
Jesse Gross's avatar
Jesse Gross committed
312
	slot.Inputs = slot.Inputs[:inputLen-discard]
313
314

	return nil
315
}