// Package kvcache provides cache implementations for storing key-value
// tensor history during transformer model inference.
package kvcache

import (
	"errors"

	"github.com/ollama/ollama/ml"
7
	"github.com/ollama/ollama/model/input"
Jesse Gross's avatar
Jesse Gross committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
)

// Sentinel errors returned by Cache implementations. Callers should
// compare against them with errors.Is.
var (
	// ErrKvCacheFull is returned when no free kv cache slot can be found.
	ErrKvCacheFull = errors.New("could not find a kv cache slot")

	// ErrNotSupported is returned when the model does not support the
	// requested cache operation.
	ErrNotSupported = errors.New("model does not support operation")
)

type Cache interface {
	// ** used by model implementations **

	// SetLayer sets the active layer of the cache
	SetLayer(layer int)

	// Get returns the history of key and value tensors plus a mask
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

	// Put stores a batch of key and value in the cache
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Put(ctx ml.Context, key, value ml.Tensor)

33
34
35
36
37
38
39
40
41
42
43
	// SetConfig controls optimizations (mostly backend-specific) that may transform
	// the output of the cache to work better with specific kernels. If not called,
	// the backend settings will be used. This works well when calling Attention.
	//
	// The config can be overridden by models, especially if they require vanilla
	// output when implementing their own version of attention. To do this, pass
	// an empty ml.CacheConfig.
	//
	// Most models will not need to use this.
	SetConfig(ml.CacheConfig)

Jesse Gross's avatar
Jesse Gross committed
44
45
	// ** cache management **

46
47
48
49
50
51
52
	// Init sets up runtime parameters.
	// backend: Used to allocate cache data storage and execute management operations (such as defrag)
	// dtype: The data type for storing cache entries
	// maxSequences: The maximum number of sequences stored in the cache - across all batches
	// capacity: The number of cache entries to store, per sequence
	// maxBatch: The maximum number of tokens that can occur in a single batch
	Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
Jesse Gross's avatar
Jesse Gross committed
53
54
55
56
57
58
59

	// Close closes the cache and frees resources associated with it
	Close()

	// StartForward is called before the start of the model's forward pass.
	// For each token in the coming batch, there must be a corresponding
	// entry in positions and seqs.
Jesse Gross's avatar
Jesse Gross committed
60
	StartForward(ctx ml.Context, batch input.Batch) error
Jesse Gross's avatar
Jesse Gross committed
61
62
63
64
65
66
67
68
69
70
71

	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
	CopyPrefix(srcSeq, dstSeq int, len int32)

	// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
	// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
	//
	// If an error occurs, the entire context for the sequence should be
	// removed by calling Remove(seq, 0, math.MaxInt32)
	Remove(seq int, beginIndex, endIndex int32) error
}