wrapper.go 2.06 KB
Newer Older
Jesse Gross's avatar
Jesse Gross committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package kvcache

import (
	"math"

	"github.com/ollama/ollama/ml"
)

// WrapperCache is a container for multiple types of caches,
// such as for the encoding and decoding portions of a model.
//
// One of the wrapped caches is the "current" one, selected by curType;
// Get, Put and UnderlyingCache operate on that cache only.
type WrapperCache struct {
	// caches we are wrapping
	caches []Cache

	// index into caches of the cache to be used for this layer;
	// set by SetLayerType and reset to 0 by a successful StartForward
	curType int
}

// NewWrapperCache returns a WrapperCache that wraps the given caches, in
// the order supplied. The current layer type starts at its zero value, so
// the first cache is the one initially selected.
func NewWrapperCache(caches ...Cache) *WrapperCache {
	w := new(WrapperCache)
	w.caches = caches
	return w
}

// Init initializes every wrapped cache, in order, with the same backend,
// data type and capacity.
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	wrapped := c.caches
	for i := range wrapped {
		wrapped[i].Init(backend, dtype, capacity)
	}
}

// Close closes every wrapped cache, in order.
func (c *WrapperCache) Close() {
	wrapped := c.caches
	for i := range wrapped {
		wrapped[i].Close()
	}
}

// StartForward calls StartForward on each wrapped cache in order, passing
// the same context, positions and sequences to each one.
//
// If any wrapped cache returns an error, the caches that had already been
// started are unwound in reverse order and that error is returned. On
// success the current layer type is reset to 0.
func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	for started, cache := range c.caches {
		if err := cache.StartForward(ctx, positions, seqs); err != nil {
			// Unwind the caches that already succeeded, most recent first.
			// Remove with endIndex set to math.MaxInt32 does not fail, so
			// its result is deliberately discarded.
			for undo := started - 1; undo >= 0; undo-- {
				prev := c.caches[undo]
				for idx, pos := range positions {
					_ = prev.Remove(seqs[idx], pos, math.MaxInt32)
				}
			}
			return err
		}
	}

	c.curType = 0
	return nil
}

// SetLayer forwards the layer index to every wrapped cache, in order.
func (c *WrapperCache) SetLayer(layer int) {
	wrapped := c.caches
	for i := range wrapped {
		wrapped[i].SetLayer(layer)
	}
}

// SetLayerType selects which wrapped cache is current. layerType is used
// as an index into the wrapped caches by Get, Put and UnderlyingCache; it
// is stored as-is without validation.
func (c *WrapperCache) SetLayerType(layerType int) {
	selected := layerType
	c.curType = selected
}

// UnderlyingCache returns the wrapped cache selected by the current
// layer type.
func (c *WrapperCache) UnderlyingCache() Cache {
	current := c.caches[c.curType]
	return current
}

// Get delegates to the wrapped cache selected by the current layer type
// and returns its three tensors unchanged.
func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	current := c.caches[c.curType]
	return current.Get(ctx)
}

// Put stores key and value in the wrapped cache selected by the current
// layer type; the other wrapped caches are not touched.
func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
	current := c.caches[c.curType]
	current.Put(ctx, key, value)
}

// CopyPrefix forwards the same srcSeq, dstSeq and len arguments to every
// wrapped cache, in order.
func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
	wrapped := c.caches
	for i := range wrapped {
		wrapped[i].CopyPrefix(srcSeq, dstSeq, len)
	}
}

// Remove forwards the removal request for seq, beginIndex and endIndex to
// each wrapped cache in order. It stops at the first wrapped cache that
// returns an error and returns that error; caches after it are not called.
// It returns nil when every wrapped cache succeeds.
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
	// If one of these fails, the caller is supposed to retry with endIndex
	// set to math.MaxInt32, which should not fail.
	wrapped := c.caches
	for i := range wrapped {
		if err := wrapped[i].Remove(seq, beginIndex, endIndex); err != nil {
			return err
		}
	}

	return nil
}