tensor.go 1.79 KB
Newer Older
1
2
3
package convert

import (
	"cmp"
	"fmt"
	"iter"
	"slices"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

Michael Yang's avatar
Michael Yang committed
15
16
17
18
19
20
21
22
// split describes one output piece produced by splitDim.
type split struct {
	// Replacer rewrites the source tensor's name into this piece's output
	// name; splitDim calls split.Replace(t.Name()).
	*strings.Replacer
	// dim is the size of this piece along the split dimension (not an axis
	// index). Zero means "take an even share": the dimension's size divided
	// (integer division) by the number of splits.
	dim int

	// fn is an optional function to apply to the tensor after slicing.
	// splitDim applies it to the materialized slice before flattening; a
	// non-nil error it returns is propagated out of the repacker.
	fn func(tensor.Tensor) (tensor.Tensor, error)
}

23
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
Michael Yang's avatar
Michael Yang committed
24
25
// is split evenly based on the number of replacers provided unless a specific count is given.
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
26
	return func(yield func(*ggml.Tensor) bool) {
Michael Yang's avatar
Michael Yang committed
27
28
29
		var offset int
		for _, split := range splits {
			t := t.Clone()
30
			shape := slices.Clone(t.Shape())
Michael Yang's avatar
Michael Yang committed
31
			shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
32
33

			slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
Michael Yang's avatar
Michael Yang committed
34
35
			slice[dim] = tensor.S(offset, offset+int(shape[dim]))
			offset += int(shape[dim])
36

Michael Yang's avatar
Michael Yang committed
37
			t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
38
39
40
41
42
				dims := make([]int, len(shape))
				for i := range shape {
					dims[i] = int(shape[i])
				}

Michael Yang's avatar
Michael Yang committed
43
44
				var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
				tt, err := tt.Slice(slice...)
45
46
47
48
				if err != nil {
					return nil, err
				}

Michael Yang's avatar
Michael Yang committed
49
50
51
52
53
54
55
56
57
				tt = tensor.Materialize(tt)

				if split.fn != nil {
					tt, err = split.fn(tt)
					if err != nil {
						return nil, err
					}
				}

58
				// flatten tensor so it can be written as a vector
Michael Yang's avatar
Michael Yang committed
59
				if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
60
61
62
					return nil, err
				}

Michael Yang's avatar
Michael Yang committed
63
				return native.VectorF32(tt.(*tensor.Dense))
64
65
66
			})

			if !yield(&ggml.Tensor{
Michael Yang's avatar
Michael Yang committed
67
				Name:     split.Replace(t.Name()),
68
69
				Kind:     t.Kind(),
				Shape:    shape,
Michael Yang's avatar
Michael Yang committed
70
				WriterTo: t,
71
72
73
74
75
76
			}) {
				break
			}
		}
	}
}