tensor.go 1.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
package convert

import (
	"iter"
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"
)

// splitDim returns an iterator that divides t along dimension dim into
// len(replacers) pieces of equal size, one piece per replacer.
//
// The piece size is t.Shape()[dim] / len(replacers). This is unsigned integer
// division, so if the dimension is not evenly divisible, the trailing
// remainder is not covered by any piece. Piece k spans the half-open range
// [k*size, (k+1)*size) along dim and the full extent of every other dimension.
//
// Each yielded ggml.Tensor:
//   - is named by applying the k-th replacer to t.Name(),
//   - keeps t.Kind(),
//   - reports the reduced shape, and
//   - uses a clone of t as its WriterTo.
//
// That clone is given a repacker which reshapes the incoming data to the
// shape it is handed, slices out piece k, and flattens the result into a
// 1-D []float32. NOTE(review): this assumes the repacker receives the full,
// unsplit tensor data and shape; confirm against the Tensor implementation.
//
// Iteration stops as soon as the consumer's yield returns false.
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
	return func(yield func(*ggml.Tensor) bool) {
		parts := uint64(len(replacers))

		for idx, rep := range replacers {
			// Clone per piece so yielded tensors never share a Shape backing array.
			pieceShape := slices.Clone(t.Shape())
			pieceShape[dim] /= parts

			lo := idx * int(pieceShape[dim])
			hi := lo + int(pieceShape[dim])

			// A nil entry selects the whole extent of that dimension; only dim is narrowed.
			window := make([]tensor.Slice, len(pieceShape))
			window[dim] = tensor.S(lo, hi)

			piece := t.Clone()
			piece.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
				dims := make([]int, 0, len(srcShape))
				for _, d := range srcShape {
					dims = append(dims, int(d))
				}

				var full tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
				view, err := full.Slice(window...)
				if err != nil {
					return nil, err
				}

				// Copy the strided view into contiguous storage, then collapse it
				// to a single dimension so it can be emitted as a plain vector.
				flat := tensor.Materialize(view)
				if err := flat.Reshape(flat.Shape().TotalSize()); err != nil {
					return nil, err
				}

				return native.VectorF32(flat.(*tensor.Dense))
			})

			out := &ggml.Tensor{
				Name:     rep.Replace(t.Name()),
				Kind:     t.Kind(),
				Shape:    pieceShape,
				WriterTo: piece,
			}
			if !yield(out) {
				return
			}
		}
	}
}