encoder.go 3.11 KB
Newer Older
Jesse Gross's avatar
Jesse Gross committed
1
2
3
package kvcache

import (
4
5
	"fmt"

Jesse Gross's avatar
Jesse Gross committed
6
7
8
9
10
11
12
13
14
15
	"github.com/ollama/ollama/ml"
)

// Encoder cache stores K and V tensors that are position independent
//
// The tensors can be of any shape and will be returned as they were stored
// The mask is currently always nil
//
// Not currently safe for multiple sequences
type EncoderCache struct {
16
17
18
	// config controls mostly backend-specific optimizations
	config *ml.CacheConfig

Jesse Gross's avatar
Jesse Gross committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
	// ** current forward pass **

	// the active layer for Get and Put
	curLayer int

	// if something is stored during this pass, this
	// will be the position (but there is no guarantee
	// anything will be stored)
	curPos int32

	// ** cache metadata **

	// was something stored in the cache?
	encoderCached bool

	// position of the cached data
	encoderPos int32

	// ** cache data storage **
38
39
40
	backend      ml.Backend
	ctxs         map[int]ml.Context
	keys, values map[int]ml.Tensor
Jesse Gross's avatar
Jesse Gross committed
41
42
43
}

func NewEncoderCache() *EncoderCache {
44
45
46
47
48
	return &EncoderCache{
		ctxs:   make(map[int]ml.Context),
		keys:   make(map[int]ml.Tensor),
		values: make(map[int]ml.Tensor),
	}
Jesse Gross's avatar
Jesse Gross committed
49
50
51
}

func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
52
53
54
55
56
57
58
59
60
61
62
63
	if c.config == nil {
		var config ml.CacheConfig
		if cc, ok := backend.(ml.BackendCacheConfig); ok {
			config = cc.CacheConfig()
		}
		c.config = &config
	}

	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
	}

64
	c.backend = backend
Jesse Gross's avatar
Jesse Gross committed
65
66
}

67
68
69
70
71
72
73
74
// SetConfig installs an explicit cache configuration.
//
// It panics if a config is already present, whether from an earlier SetConfig
// call or because Init has already populated one.
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
	if c.config != nil {
		panic("config cannot be changed after being previously set, either by the model or backend")
	}

	// config is a by-value parameter, so the cache keeps its own copy.
	c.config = &config
}

Jesse Gross's avatar
Jesse Gross committed
75
func (c *EncoderCache) Close() {
76
77
78
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
Jesse Gross's avatar
Jesse Gross committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
}

// StartForward prepares the cache for a forward pass by recording the
// position that a Put during this pass would be associated with.
//
// Only the first entry of positions is used; ctx and seqs are ignored. An
// error is returned when positions is empty, rather than panicking on the
// index expression.
func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	if len(positions) == 0 {
		return fmt.Errorf("encoder cache: forward pass started with no positions")
	}

	// The image is always in the first position
	c.curPos = positions[0]

	return nil
}

// SetLayer selects the layer that subsequent Get and Put calls operate on.
func (c *EncoderCache) SetLayer(layer int) {
	c.curLayer = layer
}

// EncoderCached reports whether the cache is currently marked as holding
// stored data: the flag is set by Put and cleared by Remove when the removed
// range covers the cached position.
func (c *EncoderCache) EncoderCached() bool {
	return c.encoderCached
}

// Get returns the key and value tensors stored for the active layer (as
// chosen by SetLayer) together with a mask, which is always nil. If nothing
// has been stored for the layer, the map zero values are returned. ctx is
// unused.
func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	layer := c.curLayer
	key, value := c.keys[layer], c.values[layer]

	return key, value, nil
}

func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.encoderPos = c.curPos
	c.encoderCached = true

104
105
106
107
	if c.config.PermutedV {
		value = value.Permute(ctx, 1, 2, 0, 3)
	}

108
109
110
111
112
113
114
115
116
117
	if _, ok := c.ctxs[c.curLayer]; !ok {
		c.ctxs[c.curLayer] = c.backend.NewContext()
	}

	if _, ok := c.keys[c.curLayer]; !ok {
		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
	}

	if _, ok := c.values[c.curLayer]; !ok {
		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
Jesse Gross's avatar
Jesse Gross committed
118
119
	}

120
121
122
123
	ctx.Forward(
		key.Copy(ctx, c.keys[c.curLayer]),
		value.Copy(ctx, c.values[c.curLayer]),
	)
Jesse Gross's avatar
Jesse Gross committed
124
125
126
127
128
129
130
131
132
133
134
135
136
}

// CopyPrefix is not supported by the encoder cache and always panics; all
// arguments are ignored.
//
// NOTE(review): the len parameter shadows the builtin of the same name within
// this method; harmless here since the body never uses it.
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
	panic("encoder cache does not support multiple sequences")
}

// Remove marks the cache as empty when the cached position lies in the
// half-open range [beginIndex, endIndex). Positions outside that range leave
// the cache untouched. seq is ignored and the returned error is always nil.
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
	pos := c.encoderPos
	if pos < beginIndex || pos >= endIndex {
		// Cached data is outside the removed range; nothing to invalidate.
		return nil
	}

	c.encoderCached = false
	return nil
}