encoder.go 3.01 KB
Newer Older
Jesse Gross's avatar
Jesse Gross committed
1
2
3
package kvcache

import (
4
5
	"fmt"

Jesse Gross's avatar
Jesse Gross committed
6
7
8
9
10
11
12
13
14
15
	"github.com/ollama/ollama/ml"
)

// Encoder cache stores K and V tensors that are position independent
//
// The tensors can be of any shape and will be returned as they were stored
// The mask is currently always nil
//
// Not currently safe for multiple sequences
type EncoderCache struct {
16
17
18
	// config controls mostly backend-specific optimizations
	config *ml.CacheConfig

Jesse Gross's avatar
Jesse Gross committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
	// ** current forward pass **

	// the active layer for Get and Put
	curLayer int

	// if something is stored during this pass, this
	// will be the position (but there is no guarantee
	// anything will be stored)
	curPos int32

	// ** cache metadata **

	// was something stored in the cache?
	encoderCached bool

	// position of the cached data
	encoderPos int32

	// ** cache data storage **

	cacheCtx     ml.Context
	keys, values []ml.Tensor
}

// NewEncoderCache returns a new, empty EncoderCache. No configuration or
// storage context is set up here; see SetConfig and Init.
func NewEncoderCache() *EncoderCache {
	return new(EncoderCache)
}

func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
48
49
50
51
52
53
54
55
56
57
58
59
	if c.config == nil {
		var config ml.CacheConfig
		if cc, ok := backend.(ml.BackendCacheConfig); ok {
			config = cc.CacheConfig()
		}
		c.config = &config
	}

	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
	}

Jesse Gross's avatar
Jesse Gross committed
60
61
62
	c.cacheCtx = backend.NewContext()
}

63
64
65
66
67
68
69
70
// SetConfig installs the cache configuration. It may only be used while no
// config is set; it panics if a config is already present (Init also sets
// one when none was provided).
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
	if c.config == nil {
		c.config = &config
		return
	}

	panic("config cannot be changed after being previously set, either by the model or backend")
}

Jesse Gross's avatar
Jesse Gross committed
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// Close closes the context that the cached tensors are allocated from.
//
// It is a no-op when Init has not been called, so that Close can safely be
// used to clean up a cache that was created but never initialized.
func (c *EncoderCache) Close() {
	// cacheCtx is only assigned in Init; calling a method on the nil
	// value would panic.
	if c.cacheCtx == nil {
		return
	}

	c.cacheCtx.Close()
}

// StartForward records the position for the upcoming forward pass.
//
// The position is taken from the first entry of positions and is only
// committed to the cache metadata if Put is called during the pass. The ctx
// and seqs arguments are not used by this cache.
//
// It returns an error if positions is empty.
func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	// Indexing an empty slice would panic; report it through the error
	// return instead.
	if len(positions) == 0 {
		return fmt.Errorf("encoder cache: forward pass requires at least one position")
	}

	// The image is always in the first position
	c.curPos = positions[0]

	return nil
}

// SetLayer selects the layer that subsequent Get and Put calls operate on.
//
// The per-layer key and value slices are extended with nil entries as
// needed so that index layer is valid.
func (c *EncoderCache) SetLayer(layer int) {
	if layer >= len(c.keys) {
		missingKeys := layer - len(c.keys) + 1
		c.keys = append(c.keys, make([]ml.Tensor, missingKeys)...)

		missingValues := layer - len(c.values) + 1
		c.values = append(c.values, make([]ml.Tensor, missingValues)...)
	}

	c.curLayer = layer
}

// EncoderCached reports whether the cache currently holds stored data: it
// becomes true once Put has been called and false again when Remove is
// called with a range that covers the cached position.
func (c *EncoderCache) EncoderCached() bool {
	return c.encoderCached
}

// Get returns the key and value tensors stored for the active layer (the
// one chosen with SetLayer), followed by the mask, which is always nil.
// The entries are returned as stored and may be nil if nothing has been
// put for this layer. The ctx argument is not used.
func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	layer := c.curLayer
	key, value := c.keys[layer], c.values[layer]

	return key, value, nil
}

func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.encoderPos = c.curPos
	c.encoderCached = true

103
104
105
106
	if c.config.PermutedV {
		value = value.Permute(ctx, 1, 2, 0, 3)
	}

Jesse Gross's avatar
Jesse Gross committed
107
	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
108
109
		c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
		c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
Jesse Gross's avatar
Jesse Gross committed
110
111
	}

112
113
114
115
	ctx.Forward(
		key.Copy(ctx, c.keys[c.curLayer]),
		value.Copy(ctx, c.values[c.curLayer]),
	)
Jesse Gross's avatar
Jesse Gross committed
116
117
118
119
120
121
122
123
124
125
126
127
128
}

// CopyPrefix is not supported by the encoder cache, which does not handle
// multiple sequences; it always panics.
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
	panic("encoder cache does not support multiple sequences")
}

// Remove marks the cache as empty when the cached position lies within
// the half-open range [beginIndex, endIndex). The stored tensors are left
// in place. The seq argument is not used and the returned error is always
// nil.
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
	// cached position is outside the removed range: nothing to do
	if c.encoderPos < beginIndex || c.encoderPos >= endIndex {
		return nil
	}

	c.encoderCached = false
	return nil
}