// Package llamarunner runs models via the llama.cpp bindings; this file
// implements image tokenization with a small embedding cache.
package llamarunner

import (
	"errors"
	"fmt"
	"hash/maphash"
	"log/slog"
	"sync"
	"time"

	"github.com/ollama/ollama/llama"
)

const imageCacheSize = 4

type ImageContext struct {
	// mu is required to be held when generating embeddings or accessing the cache
	mu sync.Mutex

20
	mtmd *llama.MtmdContext
21
22
23
24
25
26
27
28
29
30
31
32
33
34

	// cache of images to embeddings
	images    []imageCache
	imageHash maphash.Hash
}

// NewImageContext inspects the vision architecture of the model at modelPath
// and, when supported, creates a multimodal context bound to llamaContext.
// It returns an error for unreadable or unsupported architectures.
func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
	arch, err := llama.GetModelArch(modelPath)
	if err != nil {
		return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
	}

	// Only clip-style projectors are handled here.
	if arch != "clip" {
		return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
	}

	var c ImageContext
	c.mtmd, err = llama.NewMtmdContext(llamaContext, modelPath)
	if err != nil {
		return nil, err
	}

	c.images = make([]imageCache, imageCacheSize)

	return &c, nil
}

// Free releases the multimodal context, if one was created. It is safe to
// call on a nil receiver. The modelPath parameter is unused but retained
// for interface compatibility with existing callers.
func (c *ImageContext) Free(modelPath string) {
	if c == nil || c.mtmd == nil {
		return
	}
	c.mtmd.Free()
}

59
func (c *ImageContext) MultimodalTokenize(llamaContext *llama.Context, data []byte) ([]llama.MtmdChunk, error) {
60
	if c == nil {
Jesse Gross's avatar
Jesse Gross committed
61
		return nil, nil
62
63
	}

64
65
66
67
	if len(data) <= 0 {
		return nil, errors.New("received zero length image")
	}

68
69
70
71
72
	hash := c.hashImage(data)

	c.mu.Lock()
	defer c.mu.Unlock()

73
	chunks, err := c.findImage(hash)
74
	if err != nil {
75
		if c.mtmd != nil {
76
			chunks, err = c.mtmd.MultimodalTokenize(llamaContext, data)
Jesse Gross's avatar
Jesse Gross committed
77
78
79
			if err != nil {
				return nil, err
			}
80
		} else {
Jesse Gross's avatar
Jesse Gross committed
81
			return nil, errors.New("received image but vision model not loaded")
82
83
		}

84
		c.addImage(hash, chunks)
85
86
	}

87
	return chunks, nil
88
89
}

90
91
92
93
94
95
96
97
98
func (c *ImageContext) BatchSize(configuredBatchSize int) int {
	// If images are not supported, we don't need to allocate embedding batches
	if c == nil {
		return 0
	}

	return configuredBatchSize
}

99
func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
100
	return llamaContext.Model().NEmbd()
101
102
}

103
104
type imageCache struct {
	key      uint64
105
	val      []llama.MtmdChunk
106
107
108
109
110
111
112
113
114
115
116
	lastUsed time.Time
}

// hashImage returns the cache key for the given image bytes.
//
// c.imageHash is shared mutable state and maphash.Hash is not safe for
// concurrent use, so the Reset/Write/Sum64 sequence must be serialized.
// The caller (MultimodalTokenize) invokes this before taking c.mu, so the
// lock is acquired here instead. NOTE(review): if a caller ever holds c.mu
// when calling this, this will deadlock — keep the lock ordering as-is.
func (c *ImageContext) hashImage(image []byte) uint64 {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.imageHash.Reset()
	// maphash.Hash.Write never returns an error.
	_, _ = c.imageHash.Write(image)
	return c.imageHash.Sum64()
}

var errImageNotFound = errors.New("image not found in cache")

117
func (c *ImageContext) findImage(hash uint64) ([]llama.MtmdChunk, error) {
118
119
120
121
122
123
124
125
126
127
128
	for i := range c.images {
		if c.images[i].key == hash {
			slog.Debug("loading image embeddings from cache", "entry", i)
			c.images[i].lastUsed = time.Now()
			return c.images[i].val, nil
		}
	}

	return nil, errImageNotFound
}

129
func (c *ImageContext) addImage(hash uint64, embed []llama.MtmdChunk) {
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
	best := time.Now()
	var bestImage int

	for i := range c.images {
		if c.images[i].key == hash {
			bestImage = i
			break
		}

		if c.images[i].lastUsed.Compare(best) < 0 {
			best = c.images[i].lastUsed
			bestImage = i
		}
	}

	slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
	c.images[bestImage].key = hash
	c.images[bestImage].val = embed
	c.images[bestImage].lastUsed = time.Now()
}