tensor.go 3.59 KB
Newer Older
1
2
3
package convert

import (
Michael Yang's avatar
Michael Yang committed
4
	"cmp"
Michael Yang's avatar
Michael Yang committed
5
	"errors"
6
	"io"
7
	"iter"
8
	"path"
9
	"slices"
Michael Yang's avatar
Michael Yang committed
10
	"strconv"
11
12
13
14
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"
Michael Yang's avatar
Michael Yang committed
15
16

	"github.com/ollama/ollama/fs/ggml"
17
18
)

Michael Yang's avatar
Michael Yang committed
19
20
type split struct {
	*strings.Replacer
Michael Yang's avatar
Michael Yang committed
21
22
	dim    int
	slices []tensor.Slice
Michael Yang's avatar
Michael Yang committed
23

24
25
	// afterFunc is an optional function to apply to the tensor after slicing
	afterFunc func(tensor.Tensor) (tensor.Tensor, error)
Michael Yang's avatar
Michael Yang committed
26
27
}

28
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
Michael Yang's avatar
Michael Yang committed
29
30
// is split evenly based on the number of replacers provided unless a specific count is given.
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
31
	return func(yield func(*ggml.Tensor) bool) {
Michael Yang's avatar
Michael Yang committed
32
33
34
		var offset int
		for _, split := range splits {
			t := t.Clone()
35
			shape := slices.Clone(t.Shape())
Michael Yang's avatar
Michael Yang committed
36
			shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
37

Michael Yang's avatar
Michael Yang committed
38
39
			slice := split.slices
			if len(slice) == 0 {
Michael Yang's avatar
Michael Yang committed
40
				slice = slices.Repeat([]tensor.Slice{nil}, len(shape))
Michael Yang's avatar
Michael Yang committed
41
42
43
				slice[dim] = tensor.S(offset, offset+int(shape[dim]))
				offset += int(shape[dim])
			}
44

Michael Yang's avatar
Michael Yang committed
45
			t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
46
47
48
49
50
				dims := make([]int, len(shape))
				for i := range shape {
					dims[i] = int(shape[i])
				}

Michael Yang's avatar
Michael Yang committed
51
52
				var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
				tt, err := tt.Slice(slice...)
53
54
55
56
				if err != nil {
					return nil, err
				}

Michael Yang's avatar
Michael Yang committed
57
58
				tt = tensor.Materialize(tt)

59
60
				if split.afterFunc != nil {
					tt, err = split.afterFunc(tt)
Michael Yang's avatar
Michael Yang committed
61
62
63
64
65
					if err != nil {
						return nil, err
					}
				}

66
				// flatten tensor so it can be written as a vector
Michael Yang's avatar
Michael Yang committed
67
				if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
68
69
70
					return nil, err
				}

Michael Yang's avatar
Michael Yang committed
71
				return native.VectorF32(tt.(*tensor.Dense))
72
73
74
			})

			if !yield(&ggml.Tensor{
Michael Yang's avatar
Michael Yang committed
75
				Name:     split.Replace(t.Name()),
76
77
				Kind:     t.Kind(),
				Shape:    shape,
Michael Yang's avatar
Michael Yang committed
78
				WriterTo: t,
79
80
81
82
83
84
			}) {
				break
			}
		}
	}
}
85
86
87
88
89
90
91
92
93
94
95
96
97
98

// merge describes one group of tensors for mergeTensors to combine.
type merge struct {
	// pattern selects tensors by name using path.Match syntax.
	pattern string
	// name is the name given to the resulting merged tensor.
	name string
}

// mergeTensors merges tensors that match a given pattern into a single tensor.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
	var matched []Tensor
	for i := range merges {
		matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
			matched, _ := path.Match(merges[i].pattern, t.Name())
			return matched
		})

Michael Yang's avatar
Michael Yang committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
		slices.SortStableFunc(matched, func(a, b Tensor) int {
			x := strings.Split(a.Name(), ".")
			y := strings.Split(b.Name(), ".")
			if len(x) != len(y) {
				return cmp.Compare(len(x), len(y))
			}

			vals := make([]int, len(x))
			for i := range x {
				vals[i] = strings.Compare(x[i], y[i])
				m, err := strconv.ParseInt(x[i], 0, 0)
				n, err2 := strconv.ParseInt(y[i], 0, 0)
				if errors.Join(err, err2) == nil {
					vals[i] = cmp.Compare(m, n)
				}
			}

			return cmp.Or(vals...)
		})

119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
		if len(matched) > 0 {
			out = append(out, &ggml.Tensor{
				Name:     merges[i].name,
				Kind:     matched[0].Kind(),
				Shape:    append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
				WriterTo: mergeGroup(matched),
			})
		}
	}

	return out, unmatched
}

// slicesSplitFunc splits a slice into two slices based on a predicate function.
//
// Elements for which fn returns true are appended to matched and all others to
// unmatched; relative order is preserved in both. s is not modified. A result
// that receives no elements is nil.
//
// E is unconstrained because elements are never compared, so slices of
// non-comparable element types are accepted as well.
func slicesSplitFunc[S ~[]E, E any](s S, fn func(e E) bool) (matched, unmatched S) {
	for _, e := range s {
		if fn(e) {
			matched = append(matched, e)
		} else {
			unmatched = append(unmatched, e)
		}
	}

	return matched, unmatched
}

// mergeGroup is an ordered list of tensors that is written out back to back.
type mergeGroup []Tensor

// WriteTo writes every tensor in g to w in order and stops at the first error.
// It returns the total number of bytes the tensors reported writing, including
// any bytes reported by the tensor whose write failed, as io.WriterTo expects.
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
	var total int64
	for _, t := range g {
		n, err := t.WriteTo(w)
		total += n
		if err != nil {
			return total, err
		}
	}

	return total, nil
}