package convert

import (
	"cmp"
	"fmt"
	"io"
	"iter"
	"path"
	"slices"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

// split describes one output tensor produced by splitDim.
type split struct {
	// Replacer rewrites the source tensor's name into this output's name;
	// splitDim calls the promoted Replace method on t.Name().
	*strings.Replacer
	// dim is this output's size along the split dimension. Zero means "use an
	// even share": the dimension's size divided by the number of splits.
	dim int

	// fn is an optional function to apply to the tensor after slicing
	// (and materializing). A non-nil error it returns aborts the write.
	fn func(tensor.Tensor) (tensor.Tensor, error)
}

// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
Michael Yang's avatar
Michael Yang committed
26
27
// is split evenly based on the number of replacers provided unless a specific count is given.
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
28
	return func(yield func(*ggml.Tensor) bool) {
Michael Yang's avatar
Michael Yang committed
29
30
31
		var offset int
		for _, split := range splits {
			t := t.Clone()
32
			shape := slices.Clone(t.Shape())
Michael Yang's avatar
Michael Yang committed
33
			shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
34
35

			slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
Michael Yang's avatar
Michael Yang committed
36
37
			slice[dim] = tensor.S(offset, offset+int(shape[dim]))
			offset += int(shape[dim])
38

Michael Yang's avatar
Michael Yang committed
39
			t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
40
41
42
43
44
				dims := make([]int, len(shape))
				for i := range shape {
					dims[i] = int(shape[i])
				}

Michael Yang's avatar
Michael Yang committed
45
46
				var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
				tt, err := tt.Slice(slice...)
47
48
49
50
				if err != nil {
					return nil, err
				}

Michael Yang's avatar
Michael Yang committed
51
52
53
54
55
56
57
58
59
				tt = tensor.Materialize(tt)

				if split.fn != nil {
					tt, err = split.fn(tt)
					if err != nil {
						return nil, err
					}
				}

60
				// flatten tensor so it can be written as a vector
Michael Yang's avatar
Michael Yang committed
61
				if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
62
63
64
					return nil, err
				}

Michael Yang's avatar
Michael Yang committed
65
				return native.VectorF32(tt.(*tensor.Dense))
66
67
68
			})

			if !yield(&ggml.Tensor{
Michael Yang's avatar
Michael Yang committed
69
				Name:     split.Replace(t.Name()),
70
71
				Kind:     t.Kind(),
				Shape:    shape,
Michael Yang's avatar
Michael Yang committed
72
				WriterTo: t,
73
74
75
76
77
78
			}) {
				break
			}
		}
	}
}

// merge names a group of tensors to be combined into one by mergeTensors.
type merge struct {
	// pattern is a path.Match glob tested against each tensor's name.
	pattern string
	// name is the name given to the combined tensor.
	name string
}

// mergeTensors combines tensors whose names match a glob into single stacked tensors.
//
// Merges are applied in order. Each merge claims every still-unclaimed tensor whose
// name matches its pattern (path.Match). If at least one tensor matches, the function
// emits one tensor named merge.name whose shape is the first match's shape with a new
// leading dimension equal to the number of matches, and whose data is the matches
// written back to back in input order. Tensors claimed by no merge are returned as the
// second result.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
	for _, m := range merges {
		var group []Tensor
		group, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
			// path.Match only errors on a malformed pattern; such a pattern never matches.
			ok, _ := path.Match(m.pattern, t.Name())
			return ok
		})

		if len(group) == 0 {
			continue
		}

		first := group[0]
		out = append(out, &ggml.Tensor{
			Name:     m.name,
			Kind:     first.Kind(),
			Shape:    append([]uint64{uint64(len(group))}, first.Shape()...),
			WriterTo: mergeGroup(group),
		})
	}

	return out, unmatched
}

// slicesSplitFunc partitions s by fn while preserving element order. Elements for which
// fn returns true go to matched, and all others go to unmatched. A result that receives
// no elements is nil.
//
// E is constrained only by any: the elements are never compared, so non-comparable
// element types (slices, maps, funcs) are accepted as well.
func slicesSplitFunc[S ~[]E, E any](s S, fn func(e E) bool) (matched, unmatched S) {
	for _, e := range s {
		if fn(e) {
			matched = append(matched, e)
			continue
		}
		unmatched = append(unmatched, e)
	}

	return matched, unmatched
}

// mergeGroup is a set of tensors whose data is written back to back as the payload
// of a single merged tensor.
type mergeGroup []Tensor

// WriteTo writes every tensor in g to w, in order, and returns the total number of
// bytes written. It stops at the first error and returns that error together with the
// bytes written so far, as the io.WriterTo contract requires.
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
	var written int64
	for _, t := range g {
		n, err := t.WriteTo(w)
		written += n
		if err != nil {
			return written, err
		}
	}

	return written, nil
}