package ollamarunner

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"hash/maphash"
	"image"
	"log"
	"log/slog"
	"net"
	"net/http"
	"os"
	"reflect"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/image/bmp"
	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn/pooling"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	"github.com/ollama/ollama/runner/common"
	"github.com/ollama/ollama/sample"

	_ "github.com/ollama/ollama/model/models"
)

// response contains a piece of generated text along with optional logprobs
type response struct {
	content  string
	logprobs []llm.Logprob
}

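// Sequence tracks the state of a single completion or embedding request while
// it is being scheduled and decoded by the runner.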
type Sequence struct {
	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
	// multimodal embeddings
	ctxs []ml.Context

	// mmStore holds multimodal embeddings to manage memory and enable splitting across batches
	mmStore multimodalStore

	// batch index
	iBatch int

	// prompt inputs left to evaluate
	inputs []*input.Input

	// inputs that have been added to a batch but not yet submitted to Forward
	pendingInputs []*input.Input

	// tokens that have been generated but not returned yet (e.g. for stop sequences)
	pendingResponses []string

	// logprobs for tokens that haven't been returned yet
	pendingLogprobs []llm.Logprob

	// input cache being used by this sequence
	cache *InputCacheSlot

	// channel to send responses over
	responses chan response

	// channel to stop decoding (such as if the remote connection is closed)
	quit chan bool

	// number of tokens to predict
	numPredict int

	// sampler with transforms to run on generated logits
	sampler sample.Sampler

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// number of inputs to keep at the beginning when shifting context window
	numKeep int32

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool

	// shift if context window is exceeded
	shift bool

	doneReason llm.DoneReason

	// logprobs configuration
	logprobs    bool
	topLogprobs int

	// Metrics
	startedAt, lastUpdatedAt time.Time
	processingDuration       time.Duration
	samplingDuration         time.Duration
	numPredicted             int
	numPromptInputs          int
}

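// NewSequenceParams carries the per-request options used to construct a Sequence.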
type NewSequenceParams struct {
	numPredict  int
	stop        []string
	numKeep     int32
	sampler     sample.Sampler
	embedding   bool
	shift       bool
	truncate    bool
	logprobs    bool
	topLogprobs int
}

var errorInputTooLong = errors.New("the input length exceeds the context length")

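// NewSequence tokenizes the prompt (including any image inputs), applies context
// length truncation if permitted, and returns a Sequence ready to be scheduled.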
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
	s.ready.Wait()

	inputs, ctxs, mmStore, err := s.inputs(prompt, images)
	if err != nil {
		return nil, fmt.Errorf("failed to process inputs: %w", err)
	} else if len(inputs) == 0 {
		return nil, errors.New("no input provided")
	}

	if params.numKeep < 0 {
		params.numKeep = int32(len(inputs))
	}

	// Ensure that at least 1 input can be discarded during shift
	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

	if int32(len(inputs)) > s.cache.numCtx {
		if !params.truncate {
			return nil, errorInputTooLong
		}

		discard := int32(len(inputs)) - s.cache.numCtx

		promptStart := params.numKeep + discard

		// If we need to truncate in the middle of an unbreakable batch, remove the entire batch
		sameBatch := 0
		for i, inp := range inputs {
			if sameBatch > 0 {
				sameBatch--

				if promptStart == int32(i) {
					promptStart++
				}
			} else if promptStart == int32(i) {
				break
			}

			if inp.SameBatch != 0 {
				if int32(i) < params.numKeep {
					return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
				}

				sameBatch = inp.SameBatch
			}
		}

		if promptStart >= int32(len(inputs)) {
			return nil, errors.New("entire prompt removed by truncation")
		}

		newInputs := inputs[:params.numKeep]
		newInputs = append(newInputs, inputs[promptStart:]...)

		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
		inputs = newInputs
	}

	// TODO(jessegross): Ingest cached history for grammar

	return &Sequence{
		ctxs:             ctxs,
		mmStore:          mmStore,
		inputs:           inputs,
		numPromptInputs:  len(inputs),
		numPredict:       params.numPredict,
		pendingResponses: make([]string, 0),
		responses:        make(chan response, 100),
		quit:             make(chan bool, 1),
		embedding:        make(chan []float32, 1),
		sampler:          params.sampler,
		embeddingOnly:    params.embedding,
		stop:             params.stop,
		numKeep:          params.numKeep,
		shift:            params.shift,
		logprobs:         params.logprobs,
		topLogprobs:      params.topLogprobs,
	}, nil
}

// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
	decoder := func(tokenID int) string {
		text, _ := textProcessor.Decode([]int32{int32(tokenID)})
		return text
	}
	return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
}

// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
	var inputs []*input.Input
	var ctxs []ml.Context
	var mmStore multimodalStore

	var parts []string
	var matches [][]string

	multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)

	if visionModel {
		re := regexp.MustCompile(`\[img-(\d+)\]`)
		parts = re.Split(prompt, -1)
		matches = re.FindAllStringSubmatch(prompt, -1)
		mmStore = newMultimodalStore()
	} else {
		parts = []string{prompt}
	}

	for i, part := range parts {
		// text - tokenize
		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
		if err != nil {
			return nil, nil, nil, err
		}

		for _, t := range tokens {
			inputs = append(inputs, &input.Input{Token: t})
		}

		// image - decode and store
		if i < len(matches) {
			n, _ := strconv.Atoi(matches[i][1])

			imageIndex := -1
			for j := range images {
				if images[j].ID == n {
					imageIndex = j
					break
				}
			}

			if imageIndex < 0 {
				return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
			}

			ctx := s.model.Backend().NewContext()
			runtime.SetFinalizer(ctx, func(c ml.Context) { c.Close() })
			ctxs = append(ctxs, ctx)
			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
			if err != nil {
				return nil, nil, nil, err
			}

			s.multimodalHash.Reset()
			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
			imageHash := s.multimodalHash.Sum64()

			mmStore.addMultimodal(imageEmbeddings)

			inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
		}
	}

	if visionModel {
		var err error
		inputs, err = multimodalProcessor.PostTokenize(inputs)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	return inputs, ctxs, mmStore, nil
}

type batchState struct {
	// id provides a counter for trace logging batches
	id int

	// ctx holds the backend context used for this batch
	ctx ml.Context

	// modelOutput holds the outputs from this batch
	modelOutput ml.Tensor

	// batchInputs holds the input token pointers which may start as
	// placeholders later filled in before calling ctx.Compute
	batchInputs []*input.Input

	// batch contains the inputs for a model forward pass
	batch input.Batch

	// full set of seqs at the time this batch was initiated
	seqs []*Sequence

	// Signaled when this batch's inputs are ready and compute can proceed
	inputsReadyCh chan struct{}

	// Signaled when Compute is about to begin on this batch and the
	// seqs have been updated to prepare for the next batch
	computeStartedCh chan struct{}

	// Signaled when this batch's outputs are complete and the next batch can proceed
	outputsReadyCh chan struct{}
}

type Server struct {
	// modelPath is the location of the model to be loaded
	modelPath string

	// loadMu prevents more than one load attempt from occurring at a time
	loadMu sync.Mutex

	// lastLoad is the load request from the previous load attempt. Used to
	// detect if we can reuse an existing memory allocation.
	lastLoad llm.LoadRequest

	// is the server ready to process requests?
	// protects access to model and image
	ready sync.WaitGroup

	// loaded model
	model model.Model

	// status for external health reporting - loading, ready to serve, etc.
	status llm.ServerStatus

	// current progress on loading the model
	progress float32

	// number of simultaneous requests to handle
	parallel int

	// maximum number of elements in a batch (per sequence)
	// TODO (jmorganca): make this n_batch
	batchSize int

	// Simple counter used only for trace logging batches
	batchID int

	// protects access to everything below this line
	// this is context state needed for decoding
	mu sync.Mutex

	// indicates that data is ready for processing
	cond *sync.Cond

	// the list of simultaneous sequences being evaluated
	seqs []*Sequence

	// seqs can have a maximum of parallel entries, which
	// is enforced by seqsSem
	seqsSem *semaphore.Weighted

	// KV cache
	cache *InputCache

	// next sequence for prompt processing to avoid starvation
	nextSeq int

	// multimodalHash generates hashes for comparing equality
	// of non-text data
	multimodalHash maphash.Hash
}

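// allNil reports whether every sequence slot is currently unused.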
func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}
	return true
}

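// flushPending sends any buffered response text (and matching logprobs) to the
// sequence's response channel, dropping trailing bytes that would form invalid
// UTF-8. It returns false if the sequence has been cancelled.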
func flushPending(seq *Sequence) bool {
	joined := strings.Join(seq.pendingResponses, "")
	logprobs := seq.pendingLogprobs
	seq.pendingResponses = []string{}
	seq.pendingLogprobs = []llm.Logprob{}

	// Check if there are any partial UTF-8 characters remaining.
	// We already check and queue as we are generating but some may
	// still make it here:
	// - Sequence is ending, e.g. generation limit has been hit
	// - Invalid characters in the middle of a string
	// This is a stricter check to ensure we never output invalid Unicode.
	for !utf8.ValidString(joined) {
		joined = joined[:len(joined)-1]
	}

	if len(joined) == 0 {
		return true
	}

	select {
	case seq.responses <- response{content: joined, logprobs: logprobs}:
		return true
	case <-seq.quit:
		return false
	}
}

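// removeSequence flushes and closes a finished sequence, frees its cache slot,
// and releases its slot in the scheduling semaphore.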
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
	seq := s.seqs[seqIndex]

	flushPending(seq)
	seq.doneReason = reason
	close(seq.responses)
	close(seq.embedding)
	seq.cache.InUse = false
	s.seqs[seqIndex] = nil
	s.seqsSem.Release(1)
}

// track batch state between forwardBatch and computeBatch

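// run is the main decode loop: it repeatedly builds the next batch with
// forwardBatch and executes it with computeBatch, overlapping the two when the
// model supports asynchronous compute.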
func (s *Server) run(ctx context.Context) {
	s.ready.Wait()

	supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone

	var previousBatch batchState
	for {
		select {
		case <-ctx.Done():
			return
		default:
			var err error
			nextBatch, err := s.forwardBatch(previousBatch)
			if err != nil {
				panic(err)
			}

			if supportsAsync {
				go s.computeBatch(nextBatch)
			} else {
				s.computeBatch(nextBatch)
			}

			previousBatch = nextBatch
		}
	}
}

// forwardBatch gathers pending inputs from the active sequences into the next batch and builds its forward-pass graph.
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
	// If we have a pending batch still processing, wait until Compute has started
	// before setting up the next batch so the seqs inputs are ready to receive their
	// token values and we get the correct input pointers for the batchInputs
	if pendingBatch.ctx != nil {
		logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
		<-pendingBatch.computeStartedCh
		logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
		nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next inputs batch
	} else {
		logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
		// No pendingBatch, so the inputs will be ready in the seqs immediately
		nextBatch.inputsReadyCh = make(chan struct{}, 1)
		nextBatch.inputsReadyCh <- struct{}{}
	}

	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	nextBatch.ctx = s.model.Backend().NewContext()
	defer func() {
		if err != nil {
			nextBatch.ctx.Close()
			nextBatch.ctx = nil
		}
	}()
	nextBatch.id = s.batchID
	nextBatch.seqs = append([]*Sequence{}, s.seqs...)
	nextBatch.computeStartedCh = make(chan struct{}, 1)
	nextBatch.outputsReadyCh = make(chan struct{}, 1)

	// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
	var batchInputs []*input.Input
	var batchOutputs []int32
	var batch input.Batch

	resumeSeq := -1
	seqIdx := s.nextSeq - 1
	for range s.seqs {
		seqIdx = (seqIdx + 1) % len(s.seqs)
		seq := s.seqs[seqIdx]
		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
			s.removeSequence(seqIdx, llm.DoneReasonLength)
			nextBatch.seqs[seqIdx] = nil
			continue
		}

		if !s.cache.enabled {
			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
			seq.cache.Inputs = []*input.Input{}
		}

		batchSize := s.batchSize

		for i, inp := range seq.inputs {
			// If we are required to put following inputs into a single batch then extend the
			// batch size. Since we are only extending the size the minimum amount possible, this
			// will cause a break if we have existing inputs.
			minBatch := 1 + inp.SameBatch
			if minBatch > batchSize {
				batchSize = minBatch
			}

			// Stop if the required batch would put us over the total batch size (including tokens
			// added by other sequences). If we haven't been able to add anything yet then pick up
			// here again for the next batch to avoid starvation, though we can opportunistically
			// check if other sequences can still squeeze something in.
			if len(batchInputs)+minBatch > batchSize {
				if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
					resumeSeq = seqIdx
				}
				break
			}

			// If the sum of our working set (already processed tokens, tokens we added to this
			// batch, required following tokens) exceeds the context size, then trigger a shift
			// now so we don't have to do one later when we can't break the batch.
			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
				if len(seq.pendingInputs) != 0 {
					break
				}

				if !seq.shift {
					s.removeSequence(seqIdx, llm.DoneReasonLength)
					nextBatch.seqs[seqIdx] = nil
					break
				}

				err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
				if err != nil {
					var reprocess *ErrReprocessInputs
					if errors.As(err, &reprocess) {
						// Prepend these inputs to the sequence's inputs queue for reprocessing
						seq.inputs = append(reprocess.Inputs, seq.inputs...)
						// Skip this sequence but continue processing the rest
						nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
						err = nil
						continue
					} else {
						return
					}
				}
			}

			batchInputs = append(batchInputs, seq.inputs[i])
			if inp.Multimodal != nil {
				var mm []input.Multimodal
				mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
				if err != nil {
					return
				}
				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
			}

			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
			batch.Sequences = append(batch.Sequences, seq.cache.Id)

			seq.iBatch = len(batchOutputs)
			if i+1 == len(seq.inputs) || seq.embeddingOnly {
				batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
			}
			logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
			seq.pendingInputs = append(seq.pendingInputs, inp)
		}

		seq.inputs = seq.inputs[len(seq.pendingInputs):]
	}

	startedAt := time.Now()
	for i := range nextBatch.seqs {
		if nextBatch.seqs[i] != nil && nextBatch.seqs[i].startedAt.IsZero() {
			nextBatch.seqs[i].startedAt = startedAt
		}
	}

	if resumeSeq != -1 {
		s.nextSeq = resumeSeq
	} else {
		s.nextSeq = seqIdx + 1
	}

	if len(batchInputs) == 0 {
		logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
		nextBatch.ctx.Close()
		nextBatch.ctx = nil
		return
	}
	s.batchID++

	// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
	batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
	batch.Outputs = nextBatch.ctx.Input().FromInts(batchOutputs, len(batchOutputs))
	nextBatch.ctx.SetBatchSize(len(batchInputs))
	nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
	if err != nil {
		err = fmt.Errorf("failed to build graph: %w", err)
		return
	}
	nextBatch.batchInputs = batchInputs
	nextBatch.batch = batch

	return
}

// computeBatch asynchronously executes a batch prepared by forwardBatch, then samples and streams the results
func (s *Server) computeBatch(activeBatch batchState) {
	if activeBatch.ctx == nil {
		// Nothing to compute
		return
	}
	defer activeBatch.ctx.Close()

	// Wait until inputs are ready
	logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
	<-activeBatch.inputsReadyCh
	logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)

	// Once we complete, signal the next batch of inputs are ready
	// This will unblock the next computeBatch, or forwardBatch if new seqs come in
	defer func() {
		logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
		activeBatch.outputsReadyCh <- struct{}{}
	}()

	s.mu.Lock()

	// Gather the actual input token values now that they're ready
	batchInputs := make([]int32, len(activeBatch.batchInputs))
	for i := range batchInputs {
		batchInputs[i] = activeBatch.batchInputs[i].Token
	}

	// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
	// so that forwardBatch can build a batchInputs set which will eventually contain the actual
	// decoded tokens.
	nextBatchTokens := make([]*input.Input, len(s.seqs))
	iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
	for i, seq := range s.seqs {
		iBatches[i] = -1
		if seq == nil {
			continue
		}
		// Skip over any newly added or skipped sequences
		if activeBatch.seqs[i] == nil {
			continue
		}

		// Detect if the sequence we're processing has already been completed and replaced
		// with a new sequence
		if seq != activeBatch.seqs[i] {
			logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
			continue
		}

		// Pending inputs will actually be in the cache after we call Compute.
		// However, we have already resolved any placeholder tokens.
		//
		// It's possible for incoming sequences to look at the values that we've
		// added to the cache here and start relying on them before we've done
		// the computation. This is OK as long as we ensure that this batch's
		// computation happens before any future batch's and we never fail
		// (unless we take down the whole runner).
		if len(seq.pendingInputs) > 0 {
			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
			seq.pendingInputs = []*input.Input{}
		}

		// don't sample prompt processing
		if len(seq.inputs) != 0 {
			if !s.cache.enabled {
				panic("caching disabled but unable to fit entire input in a batch")
			}
			continue
		}

		seq.numPredicted++
		nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
		seq.inputs = []*input.Input{nextToken}
		nextBatchTokens[i] = nextToken
		iBatches[i] = seq.iBatch
	}

	// At this point the seqs are ready for forwardBatch to move forward so unblock
	s.mu.Unlock()

	activeBatch.batch.Inputs.FromInts(batchInputs)
	activeBatch.ctx.ComputeWithNotify(
		func() {
			logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
			activeBatch.computeStartedCh <- struct{}{}
		},
		activeBatch.modelOutput)

	outputs := activeBatch.modelOutput.Floats()
	t := time.Now()

	logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)

	s.mu.Lock()
	defer s.mu.Unlock()

	logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
	for i, seq := range s.seqs {
		if seq == nil || nextBatchTokens[i] == nil {
			continue
		}

		seq.lastUpdatedAt = t
		if seq.numPredicted == 1 {
			seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
			seq.startedAt = seq.lastUpdatedAt
		}

		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			seq.embedding <- outputs
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// sample a token
		vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
		logits := outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]
		token, err := seq.sampler.Sample(logits)
		if err != nil {
			panic("failed to sample token")
		}

		nextBatchTokens[i].Token = token

		// if it's an end of sequence token, break
		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece
			logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
		if err != nil {
			panic("failed to decode token")
		}

		// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
		if seq.logprobs {
			logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
			seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
		}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")

		if ok, stop := common.FindStop(sequence, seq.stop); ok {
			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

			var tokenTruncated bool
			origLen := len(seq.pendingResponses)
			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
			newLen := len(seq.pendingResponses)

			// Truncate logprobs to match the truncated responses
			if seq.logprobs {
				origLogprobsLen := len(seq.pendingLogprobs)
				numTokensRemoved := origLen - newLen
				newLogprobsLen := origLogprobsLen - numTokensRemoved
				if newLogprobsLen < 0 {
					newLogprobsLen = 0
				}
				seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
			}

			// Update the cache based on the tokens that will be returned:
			// - We have 1 token more than is currently in the cache because
			// the last one generated wasn't submitted to Decode
			// - Remove any stop sequences that we stripped out
			// - If truncateStop removed a portion of a token, drop that
			// - As defense-in-depth, if TruncateStop didn't find a stop token
			// remove the extra one that we added to the cache len
			tokenLen := len(seq.cache.Inputs) + 1
			tokenLen -= origLen - newLen
			if tokenTruncated || origLen == newLen {
				tokenLen--
			}

			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		if common.ContainsStopSuffix(sequence, seq.stop) {
			continue
		}

		if common.IncompleteUnicode(sequence) {
			continue
		}

		if !flushPending(seq) {
			s.removeSequence(i, llm.DoneReasonConnectionClosed)
		}
	}

	samplingDuration := time.Since(t)
	for i, seq := range s.seqs {
		if seq != nil && nextBatchTokens[i] != nil {
			s.seqs[i].samplingDuration += samplingDuration
		}
	}
}

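// completion handles POST /completion: it creates a sequence for the request,
// waits for a free slot, and streams generated chunks back as JSON.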
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req llm.CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	var grammar *sample.GrammarSampler
	var err error
	if req.Grammar != "" {
		grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
		if err != nil {
			http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
			return
		}
		defer grammar.Free()
	}

	sampler := sample.NewSampler(
		req.Options.Temperature,
		req.Options.TopK,
		req.Options.TopP,
		req.Options.MinP,
		req.Options.Seed,
		grammar,
	)

	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
		numPredict:  req.Options.NumPredict,
		stop:        req.Options.Stop,
		numKeep:     int32(req.Options.NumKeep),
		sampler:     sampler,
		embedding:   false,
		shift:       req.Shift,
		truncate:    req.Truncate,
		logprobs:    req.Logprobs,
		topLogprobs: req.TopLogprobs,
	})
	if err != nil {
		if errors.Is(err, errorInputTooLong) {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	for {
		select {
		case <-r.Context().Done():
			close(seq.quit)
			return
		case resp, ok := <-seq.responses:
			if ok {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content:  resp.content,
					Logprobs: resp.logprobs,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					close(seq.quit)
					return
				}

				flusher.Flush()
			} else {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Done:               true,
					DoneReason:         seq.doneReason,
					PromptEvalCount:    seq.numPromptInputs,
					PromptEvalDuration: seq.processingDuration,
					EvalCount:          seq.numPredicted,
					EvalDuration:       seq.lastUpdatedAt.Sub(seq.startedAt) - seq.samplingDuration,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
				}

				return
			}
		}
	}
}

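// embeddings handles POST /embedding for models with a pooling layer and
// returns the pooled embedding for the supplied content.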
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
	if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
		return
	}

	var req llm.EmbeddingRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
		embedding: true,
		truncate:  false,
	})
	if err != nil {
		if errors.Is(err, errorInputTooLong) {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting embedding request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
		Embedding:       <-seq.embedding,
		PromptEvalCount: seq.numPromptInputs,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

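// health reports the server status and load progress for external health checks.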
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
		Status:   s.status,
		Progress: s.progress,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

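// reserveWorstCaseGraph builds and reserves a worst-case compute graph (a full
// prompt-processing batch when prompt is true, a single-token decode batch
// otherwise) so the backend can pre-allocate the maximum memory the model needs.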
func (s *Server) reserveWorstCaseGraph(prompt bool) error {
	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

	var err error
	batchSize := 1
	if prompt {
		batchSize = s.batchSize
	}

	inputs := make([]*input.Input, batchSize)
	for i := range inputs {
		inputs[i] = &input.Input{}
	}
	mmStore := newMultimodalStore()

	// Multimodal strategy:
	// - Encode a 2048x2048 image. This assumes that a single image of this
	//   size is sufficient to trigger the worst case. This is currently true
	//   because for existing models, only a single image fits in a batch.
	// - Add the embedding to a full batch of tokens - this is necessary because
	//   the model may be looking for non-image data, such as <image> tags.
	// - Run PostTokenize to execute any transformations between generated
	//   embeddings and what the forward pass expects.
	// - The result may now be larger than a batch (images may not fit in a
	//   single batch), so trim based on what will fit and must be grouped together.
	// - Fill out the rest of the space with text tokens.
	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); prompt && ok {
		mmCtx := s.model.Backend().NewContext()
		defer mmCtx.Close()

		img := image.NewGray(image.Rect(0, 0, 2048, 2048))
		var buf bytes.Buffer
		bmp.Encode(&buf, img)

		if inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes()); err == nil {
			mmStore.addMultimodal(inputs[0].Multimodal)

			inputs, err = multimodalProcessor.PostTokenize(inputs)
			if err != nil {
				return err
			}

			for i, inp := range inputs {
				minBatch := 1 + inp.SameBatch
				if minBatch > s.batchSize {
					inputs = inputs[i:min(i+minBatch, len(inputs))]
					break
				} else if i+minBatch > s.batchSize {
					inputs = inputs[:i]
					break
				}
			}

			if len(inputs) < batchSize {
				newInputs := make([]*input.Input, batchSize)
				copy(newInputs, inputs)
				for i := len(inputs); i < batchSize; i++ {
					newInputs[i] = &input.Input{}
				}
				inputs = newInputs
			}
		}
	}

	var batch input.Batch

	batchInputs := make([]int32, len(inputs))
	batch.Positions = make([]int32, len(inputs))
	batch.Sequences = make([]int, len(inputs))
	for i, inp := range inputs {
		batchInputs[i] = inp.Token
		if inp.Multimodal != nil {
			mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
			if err != nil {
				return err
			}
			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
		}

		batch.Positions[i] = int32(i)
	}

	batch.Inputs = ctx.Input().FromInts(batchInputs, len(batchInputs))
	batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)

	cache := s.model.Config().Cache
	if cache != nil {
		err := cache.StartForward(ctx, batch, true)
		if err != nil {
			return err
		}
	}

	t, err := s.model.Forward(ctx, batch)
	if err != nil {
		return err
	}

	ctx.SetBatchSize(batchSize)
	ctx.Forward(t).Reserve()

	return nil
}

// allocModel pre-allocates the maximum needed memory for a model
// based on the given parameters
func (s *Server) allocModel(
	mpath string,
	params ml.BackendParams,
	loraPath []string,
	parallel int,
	kvCacheType string,
	kvSize int,
	multiUserCache bool,
) (panicErr error) {
	// Convert memory allocation panics to errors
	defer func() {
		if r := recover(); r != nil {
			if err, ok := r.(error); ok {
				var noMem ml.ErrNoMem
				if errors.As(err, &noMem) {
					panicErr = noMem
				} else {
					panic(r)
				}
			} else {
				panic(r)
			}
		}
	}()

	var err error
	s.model, err = model.New(mpath, params)
	if err != nil {
		return err
	}

	// TODO(jessegross): LoRA loading
	if len(loraPath) > 0 {
		return errors.New("loras are not yet implemented")
	}

	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
	if err != nil {
		return err
	}

	if !s.cache.enabled && parallel > 1 {
		parallel = 1
		slog.Warn("model does not support caching, disabling parallel processing")
	}

	s.parallel = parallel
	s.seqs = make([]*Sequence, s.parallel)
	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

	err = s.reserveWorstCaseGraph(true)
	if err != nil {
		return err
	}

	return s.reserveWorstCaseGraph(false)
}

// closeModel frees all memory associated with a model
func (s *Server) closeModel() {
	s.cache.Close()
	s.cache = nil
	if s.model != nil {
		s.model.Backend().Close()
		s.model = nil
	}
}

// loadModel loads the weights for a model. The memory must already
// have been allocated with allocModel
func (s *Server) loadModel() {
	err := s.model.Backend().Load(context.TODO(),
		func(progress float32) {
			s.progress = progress
		})
	if err != nil {
		panic(fmt.Errorf("failed to load model: %v", err))
	}

	s.status = llm.ServerStatusReady
	s.ready.Done()
}

1338
// load is the handler called by the Ollama server to process different
// load operations
func (s *Server) load(w http.ResponseWriter, r *http.Request) {
	s.loadMu.Lock()
	defer s.loadMu.Unlock()

	w.Header().Set("Content-Type", "application/json")

	if s.status != llm.ServerStatusLaunched {
		http.Error(w, "model already loaded", http.StatusInternalServerError)
		return
	}

	var req llm.LoadRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	slog.Info("load", "request", req)

	if req.Operation == llm.LoadOperationClose {
		s.closeModel()
		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.lastLoad.Operation = req.Operation
	loadModel := s.model == nil || !reflect.DeepEqual(req, s.lastLoad)

	s.lastLoad = req

	if loadModel {
		s.closeModel()

		params := ml.BackendParams{
			AllocMemory:    req.Operation != llm.LoadOperationFit,
			NumThreads:     req.NumThreads,
			GPULayers:      req.GPULayers,
			FlashAttention: req.FlashAttention,
		}

		s.batchSize = req.BatchSize

		err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache)
		if err != nil {
			s.closeModel()

			var noMem ml.ErrNoMem
			if errors.As(err, &noMem) {
				resp := llm.LoadResponse{Success: false, Memory: noMem.BackendMemory}
				if err := json.NewEncoder(w).Encode(&resp); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
				}

				return
			}

			http.Error(w, fmt.Sprintf("failed to initialize model: %v", err), http.StatusInternalServerError)
			return
		}
	}

	mem := s.model.Backend().BackendMemory()

	switch req.Operation {
	case llm.LoadOperationFit:
		// LoadOperationFit can't be used for anything else, so just close it
		s.closeModel()

	// LoadOperationAlloc should stay open for future operations

	case llm.LoadOperationCommit:
		s.status = llm.ServerStatusLoadingModel
		go s.loadModel()
	}

	resp := llm.LoadResponse{Success: true, Memory: mem}
	if err := json.NewEncoder(w).Encode(&resp); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

// info is the handler called by the Ollama server to report information
// about the GPU devices in use by this runner
func (s *Server) info(w http.ResponseWriter, r *http.Request) {
	s.loadMu.Lock()
	defer s.loadMu.Unlock()

	w.Header().Set("Content-Type", "application/json")

	m := s.model

	if m == nil {
		startLoad := time.Now()

		// Dummy load to get the backend wired up
		f, err := os.CreateTemp("", "*.bin")
		if err != nil {
			http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
			return
		}
		defer f.Close()
		defer os.Remove(f.Name())

		if err := ggml.WriteGGUF(f, ggml.KV{
			"general.architecture": "llama",
			"tokenizer.ggml.model": "gpt2",
		}, nil); err != nil {
			http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
			return
		}

		m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
		if err != nil {
			http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
			return
		}
		slog.Debug("dummy model load took", "duration", time.Since(startLoad))
	}

	startDevices := time.Now()
	infos := m.Backend().BackendDevices()
	slog.Debug("gathering device infos took", "duration", time.Since(startDevices))
	if err := json.NewEncoder(w).Encode(&infos); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

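// Execute is the entry point for the runner subprocess: it parses flags, starts
// the decode loop, and serves the runner HTTP API on the requested port.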
func Execute(args []string) error {
	fs := flag.NewFlagSet("runner", flag.ExitOnError)
	mpath := fs.String("model", "", "Path to model binary file")
	port := fs.Int("port", 8080, "Port to expose the server on")
	_ = fs.Bool("verbose", false, "verbose output (default: disabled)")

	fs.Usage = func() {
		fmt.Fprintf(fs.Output(), "Runner usage\n")
		fs.PrintDefaults()
	}
	if err := fs.Parse(args); err != nil {
		return err
	}
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
	slog.Info("starting ollama engine")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	server := &Server{
		modelPath: *mpath,
		status:    llm.ServerStatusLaunched,
	}

	server.cond = sync.NewCond(&server.mu)
	server.ready.Add(1)

	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return err
	}
	defer listener.Close()

	mux := http.NewServeMux()
	// TODO: support embeddings
	mux.HandleFunc("GET /info", server.info)
	mux.HandleFunc("POST /load", server.load)
	mux.HandleFunc("POST /embedding", server.embeddings)
	mux.HandleFunc("POST /completion", server.completion)
	mux.HandleFunc("GET /health", server.health)

	httpServer := http.Server{
		Handler: mux,
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
		return err
	}

	return nil
}