wrapper.go 2.22 KB
Newer Older
Jesse Gross's avatar
Jesse Gross committed
1
2
3
4
5
6
package kvcache

import (
	"math"

	"github.com/ollama/ollama/ml"
7
	"github.com/ollama/ollama/model/input"
Jesse Gross's avatar
Jesse Gross committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
)

// WrapperCache is a container for multiple types of caches, such as for the
// encoding and decoding portions of a model.
//
// Most methods (Init, SetConfig, Close, StartForward, SetLayer, CopyPrefix,
// Remove) are forwarded to every wrapped cache. Get, Put and UnderlyingCache
// instead operate on the single cache selected by SetLayerType.
type WrapperCache struct {
	// caches we are wrapping, in the order given to NewWrapperCache
	caches []Cache

	// curType is the index into caches of the cache used by Get, Put and
	// UnderlyingCache. It is set by SetLayerType and reset to 0 by a
	// successful StartForward.
	curType int
}

// NewWrapperCache returns a WrapperCache wrapping the given caches.
//
// The argument order defines the layer-type indices accepted by
// SetLayerType: the first cache is type 0, the second is type 1, and so on.
// The initially selected type is 0.
func NewWrapperCache(caches ...Cache) *WrapperCache {
	wc := new(WrapperCache)
	wc.caches = caches
	return wc
}

// Init calls Init with the same backend, dtype and capacity on every wrapped
// cache, in the order the caches were supplied.
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	wrapped := c.caches
	for i := range wrapped {
		wrapped[i].Init(backend, dtype, capacity)
	}
}

32
33
34
35
36
37
// SetConfig passes config unchanged to SetConfig on each wrapped cache.
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
	wrapped := c.caches
	for idx := 0; idx < len(wrapped); idx++ {
		wrapped[idx].SetConfig(config)
	}
}

Jesse Gross's avatar
Jesse Gross committed
38
39
40
41
42
43
// Close calls Close on each wrapped cache in turn.
func (c *WrapperCache) Close() {
	wrapped := c.caches
	for i := range wrapped {
		wrapped[i].Close()
	}
}

44
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
Jesse Gross's avatar
Jesse Gross committed
45
	for i, cache := range c.caches {
46
		err := cache.StartForward(ctx, opts)
Jesse Gross's avatar
Jesse Gross committed
47
48
49
		if err != nil {
			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
			for j := i - 1; j >= 0; j-- {
50
51
				for k := range opts.Positions {
					_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
Jesse Gross's avatar
Jesse Gross committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
				}
			}
			return err
		}
	}

	c.curType = 0
	return nil
}

// SetLayer forwards layer to SetLayer on every wrapped cache, regardless of
// which layer type is currently selected.
func (c *WrapperCache) SetLayer(layer int) {
	wrapped := c.caches
	for i := range wrapped {
		wrapped[i].SetLayer(layer)
	}
}

// SetLayerType selects which wrapped cache Get, Put and UnderlyingCache use.
// layerType is an index into the caches given to NewWrapperCache. It is not
// range-checked here, so an out-of-range value will only panic when one of
// those methods next indexes the cache slice.
func (c *WrapperCache) SetLayerType(layerType int) {
	c.curType = layerType
}

// UnderlyingCache returns the wrapped cache currently selected by
// SetLayerType.
func (c *WrapperCache) UnderlyingCache() Cache {
	selected := c.caches[c.curType]
	return selected
}

// Get delegates to Get on the currently selected wrapped cache and returns
// its three tensors unchanged.
func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	active := c.caches[c.curType]
	return active.Get(ctx)
}

// Put delegates to Put on the currently selected wrapped cache only; the
// other wrapped caches are not touched.
func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
	active := c.caches[c.curType]
	active.Put(ctx, key, value)
}

// CopyPrefix calls CopyPrefix(srcSeq, dstSeq, length) on every wrapped cache,
// so all of them receive the same prefix copy.
//
// The length parameter was previously named len, which shadowed the builtin
// len inside this function; the rename does not affect callers.
func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, length int32) {
	for _, cache := range c.caches {
		cache.CopyPrefix(srcSeq, dstSeq, length)
	}
}

// Remove calls Remove(seq, beginIndex, endIndex) on each wrapped cache in
// order and stops at the first error, which it returns unchanged.
//
// Caches before the failing one have already had the range removed, while
// those after it have not, so a failure can leave the wrapped caches out of
// step. If any of them fails, the caller is supposed to retry with endIndex
// set to math.MaxInt32, which should not fail.
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
	wrapped := c.caches
	for i := range wrapped {
		if err := wrapped[i].Remove(seq, beginIndex, endIndex); err != nil {
			return err
		}
	}
	return nil
}