wrapper.go 2.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package kvcache

// import (
// 	"math"

// 	"github.com/ollama/ollama/ml"
// 	"github.com/ollama/ollama/model/input"
// )

// // Wrapper cache is a container for multiple types of caches,
// // such as for the encoding and decoding portions of a model.
// type WrapperCache struct {
// 	// caches we are wrapping
// 	caches []Cache

// 	// cache to be used for this layer
// 	curType int
// }

// func NewWrapperCache(caches ...Cache) *WrapperCache {
// 	return &WrapperCache{
// 		caches: caches,
// 	}
// }

// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// 	for _, cache := range c.caches {
// 		cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
// 	}
// }

// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
// 	for _, cache := range c.caches {
// 		cache.SetConfig(config)
// 	}
// }

// func (c *WrapperCache) Close() {
// 	for _, cache := range c.caches {
// 		cache.Close()
// 	}
// }

// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// 	for i, cache := range c.caches {
// 		err := cache.StartForward(ctx, batch, reserve)
// 		if err != nil {
// 			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
// 			for j := i - 1; j >= 0; j-- {
// 				for k := range batch.Positions {
// 					_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
// 				}
// 			}
// 			return err
// 		}
// 	}

// 	c.curType = 0
// 	return nil
// }

// func (c *WrapperCache) SetLayer(layer int) {
// 	for _, cache := range c.caches {
// 		cache.SetLayer(layer)
// 	}
// }

// func (c *WrapperCache) SetLayerType(layerType int) {
// 	c.curType = layerType
// }

// func (c *WrapperCache) UnderlyingCache() Cache {
// 	return c.caches[c.curType]
// }

// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// 	return c.caches[c.curType].Get(ctx)
// }

// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
// 	c.caches[c.curType].Put(ctx, key, value)
// }

// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// 	for _, cache := range c.caches {
// 		cache.CopyPrefix(srcSeq, dstSeq, len)
// 	}
// }

// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
// 	for _, cache := range c.caches {
// 		if !cache.CanResume(seq, pos) {
// 			return false
// 		}
// 	}

// 	return true
// }

// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// 	// If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
// 	for _, cache := range c.caches {
// 		err := cache.Remove(seq, beginIndex, endIndex)
// 		if err != nil {
// 			return err
// 		}
// 	}

// 	return nil
// }