tensor.go 3.06 KB
Newer Older
1
2
3
package convert

import (
Michael Yang's avatar
Michael Yang committed
4
	"cmp"
5
	"io"
6
	"iter"
7
	"path"
8
9
10
11
12
	"slices"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"
Michael Yang's avatar
Michael Yang committed
13
14

	"github.com/ollama/ollama/fs/ggml"
15
16
)

Michael Yang's avatar
Michael Yang committed
17
18
type split struct {
	*strings.Replacer
Michael Yang's avatar
Michael Yang committed
19
20
	dim    int
	slices []tensor.Slice
Michael Yang's avatar
Michael Yang committed
21
22
23
24
25

	// fn is an optional function to apply to the tensor after slicing
	fn func(tensor.Tensor) (tensor.Tensor, error)
}

26
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
Michael Yang's avatar
Michael Yang committed
27
28
// is split evenly based on the number of replacers provided unless a specific count is given.
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
29
	return func(yield func(*ggml.Tensor) bool) {
Michael Yang's avatar
Michael Yang committed
30
31
32
		var offset int
		for _, split := range splits {
			t := t.Clone()
33
			shape := slices.Clone(t.Shape())
Michael Yang's avatar
Michael Yang committed
34
			shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
35

Michael Yang's avatar
Michael Yang committed
36
37
38
39
40
41
			slice := split.slices
			if len(slice) == 0 {
				slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
				slice[dim] = tensor.S(offset, offset+int(shape[dim]))
				offset += int(shape[dim])
			}
42

Michael Yang's avatar
Michael Yang committed
43
			t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
44
45
46
47
48
				dims := make([]int, len(shape))
				for i := range shape {
					dims[i] = int(shape[i])
				}

Michael Yang's avatar
Michael Yang committed
49
50
				var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
				tt, err := tt.Slice(slice...)
51
52
53
54
				if err != nil {
					return nil, err
				}

Michael Yang's avatar
Michael Yang committed
55
56
57
58
59
60
61
62
63
				tt = tensor.Materialize(tt)

				if split.fn != nil {
					tt, err = split.fn(tt)
					if err != nil {
						return nil, err
					}
				}

64
				// flatten tensor so it can be written as a vector
Michael Yang's avatar
Michael Yang committed
65
				if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
66
67
68
					return nil, err
				}

Michael Yang's avatar
Michael Yang committed
69
				return native.VectorF32(tt.(*tensor.Dense))
70
71
72
			})

			if !yield(&ggml.Tensor{
Michael Yang's avatar
Michael Yang committed
73
				Name:     split.Replace(t.Name()),
74
75
				Kind:     t.Kind(),
				Shape:    shape,
Michael Yang's avatar
Michael Yang committed
76
				WriterTo: t,
77
78
79
80
81
82
			}) {
				break
			}
		}
	}
}
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133

// merge describes one merge rule for mergeTensors.
type merge struct {
	// pattern is matched against tensor names with path.Match.
	pattern string
	// name is the name given to the merged output tensor.
	name string
}

// mergeTensors merges tensors that match a given pattern into a single tensor.
//
// The merges are applied in order. For each one, the tensors still unmatched
// are partitioned by path.Match against the merge pattern; tensors claimed by
// one merge are not offered to later ones. A non-empty group becomes a single
// *ggml.Tensor named after the merge, whose Kind comes from the first tensor
// in the group and whose Shape is that tensor's shape with the group size
// prepended. Merges that match nothing contribute no output.
//
// It returns the merged tensors and the tensors that matched no pattern.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
	for _, m := range merges {
		var group []Tensor
		group, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
			// The only error path.Match reports is a malformed pattern, in
			// which case ok is false and the tensor is left unmatched.
			ok, _ := path.Match(m.pattern, t.Name())
			return ok
		})

		if len(group) == 0 {
			continue
		}

		first := group[0]
		merged := &ggml.Tensor{
			Name:     m.name,
			Kind:     first.Kind(),
			Shape:    append([]uint64{uint64(len(group))}, first.Shape()...),
			WriterTo: mergeGroup(group),
		}
		out = append(out, merged)
	}

	return out, unmatched
}

// slicesSplitFunc splits a slice into two slices based on a predicate function.
//
// Elements for which fn returns true are appended to matched, all others to
// unmatched; relative order is preserved in both. The input slice is not
// modified. Either result is nil when no element lands in it.
//
// Elements are never compared with each other, so any element type is accepted.
func slicesSplitFunc[S ~[]E, E any](s S, fn func(e E) bool) (matched, unmatched S) {
	for _, e := range s {
		if fn(e) {
			matched = append(matched, e)
		} else {
			unmatched = append(unmatched, e)
		}
	}

	return matched, unmatched
}

// mergeGroup is an ordered list of tensors that is written out as a single
// stream, one tensor after another.
type mergeGroup []Tensor

// WriteTo writes each tensor in g to w, in slice order, and returns the total
// number of bytes written as required by io.WriterTo. It stops at the first
// error and returns it together with the bytes written up to that point.
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
	var total int64
	for _, t := range g {
		n, err := t.WriteTo(w)
		total += n
		if err != nil {
			return total, err
		}
	}

	return total, nil
}