// cache.go — last change committed by Michael Yang.

package cache

import (
	"github.com/ollama/ollama/ml"
)

// Options carries per-call parameters for Cache.Put.
type Options struct {
	// Position is the index along the third dimension (dim 2) of the cached
	// buffers at which Put starts writing the incoming entries; presumably
	// the sequence position of the first new token — confirm against callers.
	Position int
}

// Cache stores key and value tensors across calls so that later calls can
// reuse earlier entries.
type Cache interface {
	// Sub returns the sub-cache at index i (in Simple, one key buffer and one
	// value buffer per index; presumably one index per model layer).
	Sub(i int) Cache

	// Put stores key and value at the position given by opts and returns the
	// tensors callers should use in their place (in Simple, views of all
	// cached entries up to and including the newly written ones).
	Put(ctx ml.Context, key ml.Tensor, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor)
}

// Simple is a Cache with one fixed-capacity key buffer and one value buffer
// per slot. A slot's buffers are allocated with ctx.Zeros on its first Put.
type Simple struct {
	// DType is the element type of the allocated key and value buffers.
	DType ml.DType
	// Capacity is the number of positions each buffer holds; buffers are
	// sized Dim(0)*Dim(1)*Capacity of the first tensor Put into them.
	Capacity int

	// keys and values hold one buffer per slot; a nil entry means that slot's
	// buffer has not been allocated yet. Sub-caches returned by Sub view a
	// single element of these slices.
	keys   []ml.Tensor
	values []ml.Tensor
}

// Sub returns the sub-cache for slot i, growing c's slot lists so that
// index i exists.
//
// The returned *Simple shares storage with c: its keys and values are
// one-element views of c.keys[i] and c.values[i]. Buffers that Put allocates
// through the sub-cache are therefore visible to c and to later Sub(i) calls.
// The views are capped with a full slice expression. Because of the cap,
// calling Sub on the returned cache appends into fresh storage rather than
// overwriting c's neighbouring slots with nil.
//
// Note: growing c uses append, which may reallocate c's backing arrays. A
// sub-cache obtained before such a reallocation keeps pointing at the old
// arrays, so a buffer it allocates afterwards is not visible through c.
// Call Put on a sub-cache before requesting a higher index.
//
// i must be non-negative; a negative index panics.
func (c *Simple) Sub(i int) Cache {
	if i >= len(c.keys) {
		c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
		c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
	}

	return &Simple{
		DType:    c.DType,
		Capacity: c.Capacity,
		// cap == len == 1: an append on the sub-cache must not write into
		// c's slot i+1.
		keys:   c.keys[i : i+1 : i+1],
		values: c.values[i : i+1 : i+1],
	}
}

func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) {
	if c.keys[0] == nil || c.values[0] == nil {
39
40
		c.keys[0] = ctx.Zeros(c.DType, key.Dim(0)*key.Dim(1)*c.Capacity)
		c.values[0] = ctx.Zeros(c.DType, value.Dim(0)*value.Dim(1)*c.Capacity)
Michael Yang's avatar
Michael Yang committed
41
42
	}

43
44
	ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, key.Stride(2)*opts.Position, key.Dim(0)*key.Dim(1)*key.Dim(2))))
	ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, value.Stride(2)*opts.Position, value.Dim(0)*value.Dim(1)*value.Dim(2))))
Michael Yang's avatar
Michael Yang committed
45

46
	n := min(c.Capacity, key.Dim(2)+opts.Position)
Michael Yang's avatar
Michael Yang committed
47
48

	key = c.keys[0].View(ctx, 0,
49
50
		key.Dim(0), key.Stride(1),
		key.Dim(1), key.Stride(2),
Michael Yang's avatar
Michael Yang committed
51
52
53
54
		n,
	)

	value = c.values[0].View(ctx, 0,
55
56
		value.Dim(0), value.Stride(1),
		value.Dim(1), value.Stride(2),
Michael Yang's avatar
Michael Yang committed
57
58
59
60
61
62
63
		n,
	)

	// TODO shift context if necessary

	return key, value
}