package ollamarunner

import (
	"errors"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// Tensors can't be used across multiple compute graphs. This is a problem
// if a single embedding is split across batches using views since all of
// the views will have the same source tensor. We also don't want to
// recompute the entire embedding for each batch.
//
// To avoid this, we compute all of the tensors for the embedding on the
// first use and then store the result in system memory. When we need
// additional tensors, we recreate them from the stored data.

// multimodalEntry represents the embeddings of a single object (such
// as an image).
type multimodalEntry struct {
	// mm is the original set of tensors created by EncodeMultimodal
	mm []input.Multimodal

	// data is the computed result of mm, indexed in parallel with mm.
	// Nil if not yet computed
	data [][]float32
}

// multimodalStore maps from an individual tensor (of which there
// may be many in a single multimodal object) to its parent embedding.
// All tensors of one embedding share the same *multimodalEntry value.
type multimodalStore map[ml.Tensor]*multimodalEntry

// newMultimodalStore returns an empty, ready-to-use store.
func newMultimodalStore() multimodalStore {
	return multimodalStore{}
}

// addMultimodal stores an embedding for later use in a compute graph.
// Every non-nil tensor in the embedding is keyed to one shared entry.
func (m multimodalStore) addMultimodal(embedding []input.Multimodal) {
	entry := &multimodalEntry{mm: embedding}

	for i := range embedding {
		if t := embedding[i].Tensor; t != nil {
			m[t] = entry
		}
	}
}

// getMultimodal takes a source set of tensors (which may contain a whole or
// parts of one or more images) and returns the equivalent that can be used in
// the current context
51
func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
52
53
54
55
	out := make([]input.Multimodal, len(in))
	for i := range out {
		if in[i].Tensor != nil {
			var err error
56
			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
57
58
59
60
61
62
63
64
65
66
67
			if err != nil {
				return nil, err
			}
		}

		out[i].Data = in[i].Data
	}

	return out, nil
}

68
func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
	entry := m[in]

	if entry.data == nil {
		computeCtx := backend.NewContext()
		defer computeCtx.Close()

		var tensors []ml.Tensor
		for _, t := range entry.mm {
			if t.Tensor != nil {
				tensors = append(tensors, t.Tensor)
			}
		}

		if len(tensors) == 0 {
			return nil, nil
		}

86
		computeCtx.Forward(tensors...)
87
		entry.data = make([][]float32, len(entry.mm))
88
89
90
91
92
93
94
95
96
97

		if !reserve {
			computeCtx.Compute(tensors...)

			for i, t := range entry.mm {
				if t.Tensor != nil {
					entry.data[i] = t.Tensor.Floats()
				}
			}
		} else {
98
			computeCtx.Reserve()
99
100
101
102
103
		}
	}

	for i, t := range entry.mm {
		if in == t.Tensor {
104
			if !reserve {
105
				return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil
106
107
108
			} else {
				return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
			}
109
110
111
112
113
		}
	}

	return nil, errors.New("multimodal tensor not found")
}