//go:build mlx

package main

import (
	"context"
	"fmt"
	"time"
	"unicode/utf8"

	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
)

// generationStream is a dedicated stream for generation (mirroring mlx-lm's
// generation_stream).
var generationStream *mlx.Stream

// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
// This handles cases where tokenizers output partial multi-byte sequences.
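// For example, '€' encodes as the three bytes 0xE2 0x82 0xAC: if a decode
// step yields only 0xE2 0x82, Write returns "" and buffers the bytes until
// the trailing 0xAC arrives.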
type utf8Streamer struct {
	buffer []byte
}

// Write adds decoded text to the buffer and returns complete UTF-8 characters.
func (s *utf8Streamer) Write(text string) string {
	s.buffer = append(s.buffer, text...)

	// Find the last position that ends with a complete UTF-8 character
	validLen := 0
	for i := 0; i < len(s.buffer); {
		r, size := utf8.DecodeRune(s.buffer[i:])
		if r == utf8.RuneError && size == 1 {
			// Invalid or incomplete UTF-8 at this position. A UTF-8
			// sequence is at most 4 bytes, so if fewer than 4 bytes
			// remain this may be an incomplete sequence still streaming in.
			if len(s.buffer)-i < 4 {
				// Might be incomplete; keep it buffered
				break
			}
			// Definitely invalid: advance past this byte (the string
			// conversion below maps it to U+FFFD)
			i++
			validLen = i
		} else {
			i += size
			validLen = i
		}
	}

	if validLen == 0 {
		return ""
	}

	result := string(s.buffer[:validLen])
	s.buffer = s.buffer[validLen:]
	return result
}

// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
func (s *utf8Streamer) Flush() string {
	if len(s.buffer) == 0 {
		return ""
	}
	result := string(s.buffer)
	s.buffer = nil
	return result
}

func init() {
	generationStream = mlx.NewStream()
}

// withStream runs fn with the generation stream as the default, restoring
// the original default stream even if fn panics.
func withStream(fn func()) {
	orig := mlx.GetDefaultStream()
	mlx.SetDefaultStream(generationStream)
	defer mlx.SetDefaultStream(orig)
	fn()
}

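// Model is the minimal interface a text-generation model must implement:
// tokenization, vocabulary size, KV-cache construction, and a forward pass
// producing logits.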
type Model interface {
	Tokenizer() *tokenizer.Tokenizer
	VocabSize() int32
	NewCache(maxSeqLen int32) []cache.Cache
	Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
}

// ChatModel is an optional interface for models that support chat formatting
type ChatModel interface {
	FormatPrompt(prompt string) string
}

// MultimodalModel is for models that support image input
type MultimodalModel interface {
	Model
	FormatPromptWithImage(prompt string) string
	ExpandImageTokens(tokens []int32) []int32
	ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
	ImageSize() int32 // Returns expected image size for preprocessing
}

// ImageLoader loads and preprocesses an image for multimodal models
// Returns nil if path is empty
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)

type input struct {
	Prompt       string
	Image        *mlx.Array // Optional preprocessed image for multimodal models
	MaxTokens    int
	Temperature  float32
	TopP         float32
	TopK         int
	WiredLimitGB int // Metal wired memory limit in GB (default 32)
}

type output struct {
	Text          string
	Done          bool
	PrefillTokSec float64
	GenTokSec     float64
}

// Decoder wraps model + cache for autoregressive generation.
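// Typical use: NewDecoder, optionally SetImage for multimodal input, one
// prefill over the prompt tokens, then repeated step calls until EOS or the
// token budget is reached (see generate below).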
type Decoder struct {
	model         Model
	caches        []cache.Cache
	vocabSize     int32
	temp          float32
	topK          int
	topP          float32
	token         *mlx.Array   // Current token (kept across pools)
	oldCacheState []*mlx.Array // Preallocated slice for old cache state
	image         *mlx.Array   // Optional image for multimodal prefill
}

func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
	caches := m.NewCache(0)
	return &Decoder{
		model:         m,
		caches:        caches,
		vocabSize:     m.VocabSize(),
		temp:          temp,
		topK:          topK,
		topP:          topP,
		oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
	}
}

// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
	d.image = img
}

func (d *Decoder) prefill(inputIDs []int32) int {
	processed := 0

	// Track old cache state to free after each chunk
	var oldCacheState []*mlx.Array

	// For multimodal models with an image, we need to process all tokens together
	// in the first forward pass so the image embeddings can be inserted properly.
	// Skip chunking for multimodal prefill.
	isMultimodal := d.image != nil

	// Process all-but-1 tokens in chunks, eval cache state for memory management
	// Skip chunking for multimodal - process everything in the final step
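	// For example, a 5000-token prompt is processed as chunks of 2048, 2048,
	// and 903 tokens, leaving the last token for the sampling pass below.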
	if !isMultimodal {
		for len(inputIDs) > 1 {
			chunkSize := min(2048, len(inputIDs)-1)
			if chunkSize <= 0 {
				break
			}
			chunk := inputIDs[:chunkSize]

			// Save old cache state before forward
			oldCacheState = oldCacheState[:0]
			for _, c := range d.caches {
				oldCacheState = append(oldCacheState, c.State()...)
			}

			var cacheState []*mlx.Array
			withStream(func() {
				x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
				d.model.Forward(x, d.caches)
				for _, c := range d.caches {
					cacheState = append(cacheState, c.State()...)
				}
			})
			mlx.Eval(cacheState...)

			// Free old cache state
			for _, arr := range oldCacheState {
				if arr != nil {
					arr.Free()
				}
			}

			inputIDs = inputIDs[chunkSize:]
			processed += chunkSize
		}
	}

	// Save old cache state before final step
	oldCacheState = oldCacheState[:0]
	for _, c := range d.caches {
		oldCacheState = append(oldCacheState, c.State()...)
	}

	// Final token + sampling (or all tokens for multimodal)
	withStream(func() {
		x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
		mlx.Eval(x) // Materialize before any other evals

		var logits *mlx.Array
		// Use ForwardWithImage if we have an image and model supports it
		if d.image != nil {
			if mm, ok := d.model.(MultimodalModel); ok {
				logits = mm.ForwardWithImage(x, d.image, d.caches)
				d.image = nil // Only use image for first forward
			} else {
				logits = d.model.Forward(x, d.caches)
			}
		} else {
			logits = d.model.Forward(x, d.caches)
		}
		d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
	})
	// Keep cache state (token auto-kept by AsyncEval)
	for _, c := range d.caches {
		mlx.Keep(c.State()...)
	}
	mlx.AsyncEval(d.token)

	// Free old cache state from before final step
	for _, arr := range oldCacheState {
		if arr != nil {
			arr.Free()
		}
	}

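	// Release the allocator's buffer cache after the large prefill allocations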
	mlx.ClearCache()

	return processed + len(inputIDs)
}

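// step runs one decode iteration: it schedules the forward pass for the next
// token asynchronously, then blocks on the previous token's value, so GPU
// compute overlaps with host-side detokenization.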
func (d *Decoder) step() int32 {
	prevToken := d.token

	// Save old cache state (reuse preallocated slice)
	d.oldCacheState = d.oldCacheState[:0]
	for _, c := range d.caches {
		d.oldCacheState = append(d.oldCacheState, c.State()...)
	}

	withStream(func() {
		logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
		d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
	})
	// Keep token and new cache state so they survive cleanup
	mlx.Keep(d.token)
	for _, c := range d.caches {
		mlx.Keep(c.State()...)
	}
	mlx.AsyncEval(d.token)

	// Sync on previous token (GPU already working on next step)
	val := prevToken.ItemInt32()

	// Free old token and old cache state
	prevToken.Free()
	for _, arr := range d.oldCacheState {
		if arr != nil {
			arr.Free()
		}
	}
	return val
}

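// generate runs autoregressive decoding for in.Prompt, streaming text and
// final throughput stats through cb. A minimal call might look like:
//
//	err := generate(ctx, model, input{Prompt: "Hello", MaxTokens: 256}, func(o output) {
//		fmt.Print(o.Text)
//	})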
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
	mlx.EnableCompile()
	wiredLimit := in.WiredLimitGB
	if wiredLimit <= 0 {
		wiredLimit = 32 // default 32GB
	}
	mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)

	// A negative temperature means "unset" and defaults to 0.7; zero is
	// kept as-is (greedy sampling).
	temp := in.Temperature
	if temp < 0 {
		temp = 0.7
	}

	tok := m.Tokenizer()
	dec := NewDecoder(m, temp, in.TopK, in.TopP)

	// Apply chat template - use image template if we have an image
	prompt := in.Prompt
	var tokens []int32
	if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
		prompt = mm.FormatPromptWithImage(prompt)
		tokens = tok.Encode(prompt, true)
		tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
		dec.SetImage(in.Image)
	} else if cm, ok := m.(ChatModel); ok {
		prompt = cm.FormatPrompt(prompt)
		tokens = tok.Encode(prompt, true)
	} else {
		tokens = tok.Encode(prompt, true)
	}

	prefillStart := time.Now()
	prefillTokens := dec.prefill(tokens)
	// Prefill measurement includes time to first token (matching mlx-lm);
	// step() waits for prefill to complete and returns the first token.
	firstToken := dec.step()
	prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()

	genStart := time.Now()
	maxTokens := max(in.MaxTokens, 100) // enforce a floor of 100 tokens
	var genTokens int

	// UTF-8 streamer to handle partial multi-byte characters
	streamer := &utf8Streamer{}

	// Handle first token
	genTokens++
	if tok.IsEOS(firstToken) {
		cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
		return nil
	}
	if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
		cb(output{Text: text})
	}

	for n := 1; n < maxTokens; n++ {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		token := dec.step()
		genTokens++

		if tok.IsEOS(token) {
			break
		}
		if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
			cb(output{Text: text})
		}

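		// Periodically release the allocator's buffer cache to bound memory growth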
		if n%256 == 0 {
			mlx.ClearCache()
		}
	}

	// Flush any remaining buffered bytes
	if text := streamer.Flush(); text != "" {
		cb(output{Text: text})
	}

	fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
	cb(output{Done: true, PrefillTokSec: prefillTokSec,
		GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
	return nil
}