model.go 5.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
package qwen25vl

import (
	"bytes"
	"fmt"
	"image"
	"slices"
	"sync"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

// Model is the Qwen 2.5 VL multimodal model. It bundles the common
// model.Base, a byte-pair-encoding tokenizer, the text decoder, the vision
// encoder, and the image preprocessor used to turn raw images into patches.
//
// NOTE(review): the `gguf:"v,vision"` tag presumably maps the vision encoder's
// weights to GGUF tensors under the "v" prefix — confirm against model.Base's
// tag handling.
type Model struct {
	model.Base
	model.BytePairEncoding

	*TextModel
	*VisionModel `gguf:"v,vision"`

	ImageProcessor
}

// Compile-time check that *Model implements model.MultimodalProcessor.
var _ model.MultimodalProcessor = (*Model)(nil)

// New builds a Qwen 2.5 VL model from the GGUF metadata in c.
//
// The tokenizer is a byte-pair encoder configured from the tokenizer.ggml.*
// keys; when no pretokenizer is configured, the built-in regex below is used.
// The end-of-turn token and its add flag mirror the end-of-sequence settings.
// Text, vision and image-preprocessing components are each constructed from c,
// and the KV cache is a causal cache using the text model's Shift function.
// The returned error is always nil.
func New(c fs.Config) (model.Model, error) {
	pretokenizer := c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`)

	// EOT deliberately reuses the EOS id and the EOS add flag.
	eosID := int32(c.Uint("tokenizer.ggml.eos_token_id"))
	addEOS := c.Bool("tokenizer.ggml.add_eos_token", false)

	vocab := &model.Vocabulary{
		Values: c.Strings("tokenizer.ggml.tokens"),
		Types:  c.Ints("tokenizer.ggml.token_type"),
		Merges: c.Strings("tokenizer.ggml.merges"),
		BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
		EOS:    eosID,
		AddEOS: addEOS,
		EOT:    eosID,
		AddEOT: addEOS,
	}

	m := &Model{
		BytePairEncoding: model.NewBytePairEncoding(pretokenizer, vocab),
		TextModel:        NewTextModel(c),
		VisionModel:      newVisionModel(c),
		ImageProcessor:   newImageProcessor(c),
	}
	m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)

	return m, nil
}

// PixelValues decodes multimodalData as an image, runs it through the model's
// ImageProcessor, and uploads the resulting patch values into an input tensor.
//
// The tensor has dimensions (patchDim, numPatches), where
//
//	patchDim   = numChannels * temporalPatchSize * patchSize * patchSize
//	numPatches = grid.Temporal * grid.Height * grid.Width
//
// and the returned *Grid is the one reported by ProcessImage.
// image.Decode only recognizes formats whose decoders are registered with the
// image package elsewhere in the program. Decode and processing errors are
// wrapped with %w, so errors.Is/As still see the underlying cause.
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
	// Named img (not image) so the image package is not shadowed.
	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, nil, fmt.Errorf("decoding image: %w", err)
	}

	f32s, grid, err := m.ImageProcessor.ProcessImage(img)
	if err != nil {
		return nil, nil, fmt.Errorf("processing image: %w", err)
	}

	// Calculate tensor dimensions: one column of patchDim values per patch.
	patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
		m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
	numPatches := grid.Temporal * grid.Height * grid.Width

	pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
	}

	return pixelValues, grid, nil
}

// EncodeMultimodal runs the vision encoder over one encoded image and returns
// the encoder output wrapped in a *chunks. PostTokenize expects exactly this
// type in the Multimodal field of image inputs.
//
// It returns model.ErrNoVisionModel when the model has no usable vision
// encoder: either the VisionModel pointer is nil or it has zero layers. Errors
// from PixelValues are returned unchanged.
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
	// Check for nil before reading Layers, which would otherwise panic on a
	// missing vision model.
	if m.VisionModel == nil || len(m.VisionModel.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}

	pixels, grid, err := m.PixelValues(ctx, multimodalData)
	if err != nil {
		return nil, err
	}

	visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
	return &chunks{Model: m, Tensor: visionOutputs}, nil
}

// chunks holds the vision-encoder output tensor for one image, together with
// the model that produced it.
//
// The tensor's values are copied to the host lazily: dataOnce guards a single
// computation that fills data. The chunk views created from the same *chunks
// share that cache, because they all point at this struct.
type chunks struct {
	*Model
	ml.Tensor

	dataOnce sync.Once // ensures the tensor is computed and read back only once
	data     []float32 // cached tensor values, filled by (*chunk).floats
}

// chunk is a view over part of a chunks tensor.
//
// s is the index of the first entry along dimension 1, and n is the number of
// entries. Each entry spans Dim(0) floats in the cached data.
// NOTE(review): assumes dimension 0 is contiguous in the flattened data, as the
// slicing in floats implies — confirm.
type chunk struct {
	*chunks
	s, n int
}

// floats returns the values of the parent tensor that this chunk covers.
//
// The first call on any chunk of a given parent computes the parent tensor in
// a throwaway backend context and caches all of its values. Every call then
// slices that shared cache. The result aliases the cache, so callers must not
// modify it.
func (r *chunk) floats() []float32 {
	r.dataOnce.Do(func() {
		scratch := r.Backend().NewContext()
		defer scratch.Close()

		scratch.Forward(r.Tensor).Compute(r.Tensor)
		r.data = r.Floats()
	})

	rowLen := r.Dim(0)
	lo := r.s * rowLen
	hi := (r.s + r.n) * rowLen
	return r.data[lo:hi]
}

// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass.
//
// Text inputs (Multimodal == nil) are copied through unchanged. Each image
// input must carry the *chunks produced by EncodeMultimodal, and is expanded
// into:
//
//	tokens of " Picture N: "   (N counts images starting at 1)
//	visionStartToken
//	imageToken carrying the whole embedding as a *chunk (s=0, n=patches)
//	patches-1 placeholder imageTokens
//	visionEndToken
//
// It returns an error in three cases: the prefix cannot be encoded, an image
// input carries an unexpected payload type, or the image has no patches.
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
	// Special token IDs for the image placeholder and the vision start/end
	// markers. NOTE(review): these are hard-coded rather than looked up in the
	// vocabulary, so they must match the tokenizer shipped with the model.
	const (
		imageToken       int32 = 151655
		visionStartToken int32 = 151652
		visionEndToken   int32 = 151653
	)

	var result []input.Input
	nImg := 0
	for _, inp := range inputs {
		if inp.Multimodal == nil {
			// If not a multimodal input, add it to the result unchanged.
			result = append(result, inp)
			continue
		}

		// Use a checked assertion so an unexpected payload is reported as an
		// error instead of panicking.
		chunksData, ok := inp.Multimodal.(*chunks)
		if !ok {
			return nil, fmt.Errorf("unexpected multimodal input type %T", inp.Multimodal)
		}
		patchesPerChunk := chunksData.Dim(1)
		if patchesPerChunk < 1 {
			// slices.Repeat below panics on a negative count.
			return nil, fmt.Errorf("image has no patches (dim 1 = %d)", patchesPerChunk)
		}

		// Adding the 'Picture' prefix is a hack: at the time of writing there is
		// no way to prefix the image tokens with a prompt, so we add one here.
		nImg++
		pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true)
		if err != nil {
			return nil, fmt.Errorf("failed to encode image prompt: %w", err)
		}
		for _, tok := range pre {
			result = append(result, input.Input{Token: tok})
		}

		// First add the vision start token.
		result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 2})

		// The first image token carries the multimodal tensor data for all patches.
		result = append(result, input.Input{
			Token:          imageToken,
			Multimodal:     &chunk{chunks: chunksData, s: 0, n: patchesPerChunk},
			MultimodalHash: inp.MultimodalHash,
			SameBatch:      patchesPerChunk,
		})

		// Placeholder tokens for the remaining positions (patchesPerChunk-1).
		result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)

		result = append(result, input.Input{Token: visionEndToken})
	}

	return result, nil
}

// Forward runs the text model over one batch.
//
// It uploads the batch's positions and output indices as integer input
// tensors, then delegates to TextModel.Forward together with the batch inputs,
// the batch itself, and the model's KV cache. Tensor-creation errors are
// returned unchanged.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	posTensor, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
	if err != nil {
		return nil, err
	}

	outTensor, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
	if err != nil {
		return nil, err
	}

	return m.TextModel.Forward(ctx, batch.Inputs, posTensor, outTensor, batch, m.Cache)
}

// init registers New as the constructor for models under the name "qwen25vl".
func init() {
	model.Register("qwen25vl", New)
}