package ollamarunner

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"hash/maphash"
	"log"
	"log/slog"
	"net"
	"net/http"
	"os"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	"github.com/ollama/ollama/runner/common"
	"github.com/ollama/ollama/sample"

	_ "github.com/ollama/ollama/model/models"
)

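// Sequence tracks the state of a single in-flight request: its remaining
// prompt inputs, cache slot, sampling configuration, buffered output, and
// timing metrics.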
type Sequence struct {
	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
	// multimodal embeddings
	ctxs []ml.Context

	// mmStore holds multimodal embeddings to manage memory and enable splitting across batches
	mmStore multimodalStore

	// batch index
	iBatch int

	// prompt inputs left to evaluate
	inputs []input.Input

	// inputs that have been added to a batch but not yet submitted to Forward
	pendingInputs []input.Input

	// tokens that have been generated but not returned yet (e.g. for stop sequences)
	pendingResponses []string

	// input cache being used by this sequence
	cache *InputCacheSlot

	// channel to send responses over
	responses chan string

	// channel to stop decoding (such as if the remote connection is closed)
	quit chan bool

	// number of tokens to predict
	numPredict int

	// sampler with transforms to run on generated logits
	sampler sample.Sampler

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// number of inputs to keep at the beginning when shifting context window
	numKeep int32

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool

	doneReason llm.DoneReason

	// Metrics
	startProcessingTime time.Time
	startGenerationTime time.Time
	numPredicted        int
	numPromptInputs     int
}

type NewSequenceParams struct {
	numPredict int
	stop       []string
	numKeep    int32
	sampler    sample.Sampler
	embedding  bool
}

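// NewSequence tokenizes the prompt (decoding any referenced images), truncates
// the inputs to fit the context window while honoring numKeep and unbreakable
// batches, and returns a Sequence ready to be scheduled.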
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
	s.ready.Wait()

	startTime := time.Now()

	inputs, ctxs, mmStore, err := s.inputs(prompt, images)
	if err != nil {
		return nil, fmt.Errorf("failed to process inputs: %w", err)
	} else if len(inputs) == 0 {
		return nil, errors.New("no input provided")
	}

	if params.numKeep < 0 {
		params.numKeep = int32(len(inputs))
	}

	// Ensure that at least 1 input can be discarded during shift
	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

	if int32(len(inputs)) > s.cache.numCtx {
		discard := int32(len(inputs)) - s.cache.numCtx
		promptStart := params.numKeep + discard

		// If we need to truncate in the middle of an unbreakable batch, remove the entire batch
		sameBatch := 0
		for i, inp := range inputs {
			if sameBatch > 0 {
				sameBatch--

				if promptStart == int32(i) {
					promptStart++
				}
			} else if promptStart == int32(i) {
				break
			}

			if inp.SameBatch != 0 {
				if int32(i) < params.numKeep {
					return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
				}

				sameBatch = inp.SameBatch
			}
		}

		if promptStart >= int32(len(inputs)) {
			return nil, errors.New("entire prompt removed by truncation")
		}

		newInputs := inputs[:params.numKeep]
		newInputs = append(newInputs, inputs[promptStart:]...)

		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
		inputs = newInputs
	}

	// TODO(jessegross): Ingest cached history for grammar

	return &Sequence{
		ctxs:                ctxs,
		mmStore:             mmStore,
		inputs:              inputs,
		numPromptInputs:     len(inputs),
		startProcessingTime: startTime,
		numPredict:          params.numPredict,
		pendingResponses:    make([]string, 0),
		responses:           make(chan string, 100),
		quit:                make(chan bool, 1),
		embedding:           make(chan []float32, 1),
		sampler:             params.sampler,
		embeddingOnly:       params.embedding,
		stop:                params.stop,
		numKeep:             params.numKeep,
	}, nil
}

// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
	var inputs []input.Input
	var ctxs []ml.Context
	var mmStore multimodalStore

	var parts []string
	var matches [][]string

	multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)

	if visionModel {
		re := regexp.MustCompile(`\[img-(\d+)\]`)
		parts = re.Split(prompt, -1)
		matches = re.FindAllStringSubmatch(prompt, -1)
		mmStore = newMultimodalStore()
	} else {
		parts = []string{prompt}
	}

	postTokenize := false
	for i, part := range parts {
		// text - tokenize
		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
		if err != nil {
			return nil, nil, nil, err
		}

		for _, t := range tokens {
			inputs = append(inputs, input.Input{Token: t})
		}

		// image - decode and store
		if i < len(matches) {
			n, _ := strconv.Atoi(matches[i][1])

			imageIndex := -1
			for j := range images {
				if images[j].ID == n {
					imageIndex = j
					break
				}
			}

			if imageIndex < 0 {
				return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
			}

			ctx := s.model.Backend().NewContext()
			runtime.SetFinalizer(ctx, func(c ml.Context) { c.Close() })
			ctxs = append(ctxs, ctx)
			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
			if err != nil {
				return nil, nil, nil, err
			}

			s.multimodalHash.Reset()
			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
			imageHash := s.multimodalHash.Sum64()

			mmStore.addMultimodal(imageEmbeddings)

			inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
			postTokenize = true
		}
	}

	if visionModel && postTokenize {
		var err error
		inputs, err = multimodalProcessor.PostTokenize(inputs)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	return inputs, ctxs, mmStore, nil
}

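// Server holds the loaded model and the shared state used by both the HTTP
// handlers and the batch processing loop.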
type Server struct {
	// is the server ready to process requests?
	// protects access to model and image
	ready sync.WaitGroup

	// loaded model
	model model.Model

	// status for external health reporting - loading, ready to serve, etc.
	status llm.ServerStatus

	// current progress on loading the model
	progress float32

	// number of simultaneous requests to handle
	parallel int

	// maximum number of elements in a batch (per sequence)
	// TODO (jmorganca): make this n_batch
	batchSize int

	// protects access to everything below this line
	// this is context state needed for decoding
	mu sync.Mutex

	// indicates that data is ready for processing
	cond *sync.Cond

	// the list of simultaneous sequences being evaluated
	seqs []*Sequence

	// seqs can have a maximum of parallel entries, which
	// is enforced by seqsSem
	seqsSem *semaphore.Weighted

	// KV cache
	cache *InputCache

	// next sequence for prompt processing to avoid starvation
	nextSeq int

	// multimodalHash generates hashes for comparing equality
	// of non-text data
	multimodalHash maphash.Hash
}

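// allNil reports whether every sequence slot is currently empty.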
func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}
	return true
}

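// flushPending sends any buffered response text for the sequence, dropping
// trailing bytes that would form invalid UTF-8. It returns false if the client
// connection has gone away (the quit channel fired).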
func flushPending(seq *Sequence) bool {
	joined := strings.Join(seq.pendingResponses, "")
	seq.pendingResponses = []string{}

	// Check if there are any partial UTF-8 characters remaining.
	// We already check and queue as we are generating but some may
	// still make it here:
	// - Sequence is ending, e.g. generation limit has been hit
	// - Invalid characters in the middle of a string
	// This is a stricter check to ensure we never output invalid Unicode.
	for !utf8.ValidString(joined) {
		joined = joined[:len(joined)-1]
	}

	if len(joined) == 0 {
		return true
	}

	select {
	case seq.responses <- joined:
		return true
	case <-seq.quit:
		return false
	}
}

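// removeSequence flushes remaining output, records the done reason, frees the
// sequence's cache slot, and releases its semaphore permit. It is called from
// processBatch while s.mu is held.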
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
	seq := s.seqs[seqIndex]

	flushPending(seq)
	seq.doneReason = reason
	close(seq.responses)
	close(seq.embedding)
	seq.cache.InUse = false
	s.seqs[seqIndex] = nil
	s.seqsSem.Release(1)
}

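// run is the main decode loop: it processes batches until ctx is cancelled.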
func (s *Server) run(ctx context.Context) {
	s.ready.Wait()

	for {
		select {
		case <-ctx.Done():
			return
		default:
			err := s.processBatch()
			if err != nil {
				panic(err)
			}
		}
	}
}

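// processBatch assembles a single batch across all active sequences, runs the
// model forward pass, then samples one token for each sequence that has
// finished prompt processing, handling stop sequences and sequence removal
// along the way.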
func (s *Server) processBatch() error {
	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

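	// batchInputs collects the token IDs for the whole batch; batch carries the
	// matching positions, sequence IDs, output indices, and multimodal data.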
	var batchInputs []int32
	var batch input.Batch

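	// Walk the sequences round-robin starting at nextSeq, which records where
	// prompt processing was cut off last time, so that no sequence is starved.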
	resumeSeq := -1
	seqIdx := s.nextSeq - 1
	for range s.seqs {
		seqIdx = (seqIdx + 1) % len(s.seqs)
		seq := s.seqs[seqIdx]

		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
			s.removeSequence(seqIdx, llm.DoneReasonLength)
			continue
		}

		if !s.cache.enabled {
			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
			seq.cache.Inputs = []input.Input{}
		}

		batchSize := s.batchSize

		for i, inp := range seq.inputs {
			// If we are required to put the following inputs into a single batch then extend the
			// batch size. Since we are only extending the size the minimum amount possible, this
			// will cause a break if we have existing inputs.
			minBatch := 1 + inp.SameBatch
			if minBatch > batchSize {
				batchSize = minBatch
			}

			// Stop if the required batch would put us over the total batch size (including tokens
			// added by other sequences). If we haven't been able to add anything yet then pick up
			// here again for the next batch to avoid starvation, though we can opportunistically
			// check if other sequences can still squeeze something in.
			if len(batchInputs)+minBatch > batchSize {
				if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
					resumeSeq = seqIdx
				}
				break
			}

			// If the sum of our working set (already processed tokens, tokens we added to this
			// batch, required following tokens) exceeds the context size, then trigger a shift
			// now so we don't have to do one later when we can't break the batch.
			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
				if len(seq.pendingInputs) != 0 {
					break
				}

				err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
				if err != nil {
					var reprocess *ErrReprocessInputs
					if errors.As(err, &reprocess) {
						// Prepend these inputs to the sequence's inputs queue for reprocessing
						seq.inputs = append(reprocess.Inputs, seq.inputs...)
						// Skip this sequence but continue processing the rest
						continue
					} else {
						return err
					}
				}
			}

			batchInputs = append(batchInputs, inp.Token)
			if inp.Multimodal != nil {
				mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal)
				if err != nil {
					return err
				}
				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
			}

			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
			batch.Sequences = append(batch.Sequences, seq.cache.Id)

			seq.iBatch = len(batch.Outputs)
			if i+1 == len(seq.inputs) {
				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
			}
			seq.pendingInputs = append(seq.pendingInputs, inp)
		}

		seq.inputs = seq.inputs[len(seq.pendingInputs):]
	}

	if resumeSeq != -1 {
		s.nextSeq = resumeSeq
	} else {
		s.nextSeq = seqIdx + 1
	}

	if len(batchInputs) == 0 {
		return nil
	}

	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
	if err != nil {
		return fmt.Errorf("failed to decode batch: %w", err)
	}

	logits := modelOutput.Floats()

	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}

		// After calling Forward, pending inputs are now in the cache
		if len(seq.pendingInputs) > 0 {
			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
			seq.pendingInputs = []input.Input{}
		}

		// don't sample prompt processing
		if len(seq.inputs) != 0 {
			if !s.cache.enabled {
				return errors.New("caching disabled but unable to fit entire input in a batch")
			}
			continue
		}

		seq.numPredicted++
		if seq.numPredicted == 1 {
			seq.startGenerationTime = time.Now()
		}

		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			// TODO(jessegross): Embedding support
			slog.Warn("generation of embedding outputs not yet supported")
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// sample a token
		vocabSize := len(logits) / len(batch.Outputs)

		token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
		if err != nil {
			return fmt.Errorf("failed to sample token: %w", err)
		}

		// if it's an end of sequence token, break
		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
		if err != nil {
			return err
		}

		seq.inputs = []input.Input{{Token: token}}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")

		if ok, stop := common.FindStop(sequence, seq.stop); ok {
			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

			var tokenTruncated bool
			origLen := len(seq.pendingResponses)
			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
			newLen := len(seq.pendingResponses)

			// Update the cache based on the tokens that will be returned:
			// - We have 1 token more than is currently in the cache because
			// the last one generated wasn't submitted to Decode
			// - Remove any stop sequences that we stripped out
			// - If TruncateStop removed a portion of a token, drop that
			// - As defense-in-depth, if TruncateStop didn't find a stop token
			// remove the extra one that we added to the cache len
			tokenLen := len(seq.cache.Inputs) + 1
			tokenLen -= origLen - newLen
			if tokenTruncated || origLen == newLen {
				tokenLen--
			}
			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		if common.ContainsStopSuffix(sequence, seq.stop) {
			continue
		}

		if common.IncompleteUnicode(sequence) {
			continue
		}

		if !flushPending(seq) {
			s.removeSequence(i, llm.DoneReasonConnectionClosed)
		}
	}

	return nil
}

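// completion handles POST /completion: it decodes the request, builds a
// sampler (optionally constrained by a grammar), creates and schedules a
// sequence, and streams CompletionResponse chunks until generation finishes or
// the client disconnects.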
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req llm.CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	var grammar *sample.GrammarSampler
	var err error
	if req.Grammar != "" {
		grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
		if err != nil {
			http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
			return
		}
		defer grammar.Free()
	}

	sampler := sample.NewSampler(
		req.Options.Temperature,
		req.Options.TopK,
		req.Options.TopP,
		req.Options.MinP,
		req.Options.Seed,
		grammar,
	)

	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
		numPredict: req.Options.NumPredict,
		stop:       req.Options.Stop,
		numKeep:    int32(req.Options.NumKeep),
		sampler:    sampler,
		embedding:  false,
	})
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	for {
		select {
		case <-r.Context().Done():
			close(seq.quit)
			return
		case content, ok := <-seq.responses:
			if ok {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content: content,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					close(seq.quit)
					return
				}

				flusher.Flush()
			} else {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Done:               true,
					DoneReason:         seq.doneReason,
					PromptEvalCount:    seq.numPromptInputs,
					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
					EvalCount:          seq.numPredicted,
					EvalDuration:       time.Since(seq.startGenerationTime),
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
				}

				return
			}
		}
	}
}

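// health handles GET /health, reporting the server status and model load
// progress.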
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
		Status:   s.status,
		Progress: s.progress,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

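// multiLPath implements flag.Value so the -lora flag can be passed multiple
// times, accumulating one path per occurrence.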
type multiLPath []string

func (m *multiLPath) Set(value string) error {
	*m = append(*m, value)
	return nil
}

func (m *multiLPath) String() string {
	return strings.Join(*m, ", ")
}

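// reserveWorstCaseGraph builds a dummy batch of maximum size (batchSize
// tokens with one output per parallel sequence) and reserves the resulting
// graph with the backend so worst-case resources are allocated up front.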
func (s *Server) reserveWorstCaseGraph() error {
	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

	var batch input.Batch

	inputs := make([]int32, s.batchSize)
	batch.Positions = make([]int32, len(inputs))
	batch.Sequences = make([]int, len(inputs))
	for i := range inputs {
		batch.Positions[i] = int32(i)
	}

	batch.Outputs = make([]int32, s.parallel)
	for i := range batch.Outputs {
		batch.Outputs[i] = int32(i)
	}

	var err error
	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
	if err != nil {
		return err
	}

	cache := s.model.Config().Cache
	if cache != nil {
		err := cache.StartForward(ctx, batch, true)
		if err != nil {
			return err
		}
	}

	t, err := s.model.Forward(ctx, batch)
	if err != nil {
		return err
	}

	err = ctx.Forward(t).Reserve()
	if err != nil {
		return err
	}

	return nil
}

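// loadModel loads the model and input cache, reserves the worst-case graph,
// and marks the server ready. It runs in its own goroutine during startup and
// panics on failure.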
func (s *Server) loadModel(
	ctx context.Context,
	mpath string,
	params ml.BackendParams,
	lpath multiLPath,
	parallel int,
	kvCacheType string,
	kvSize int,
	multiUserCache bool,
) {
	var err error
	s.model, err = model.New(ctx, mpath, params)
	if err != nil {
		panic(err)
	}

	// TODO(jessegross): LoRA loading
	if lpath.String() != "" {
		panic("loras are not yet implemented")
	}

	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
	if err != nil {
		panic(err)
	}

	if !s.cache.enabled && parallel > 1 {
		parallel = 1
		slog.Warn("model does not support caching, disabling parallel processing")
	}

	s.parallel = parallel
	s.seqs = make([]*Sequence, s.parallel)
	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

	err = s.reserveWorstCaseGraph()
	if err != nil {
		panic(err)
	}

	s.status = llm.ServerStatusReady
	s.ready.Done()
}

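// Execute is the entry point for the ollama engine runner: it parses flags,
// starts model loading and the batch loop, and serves the HTTP API
// (POST /completion, GET /health) on the configured port.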
func Execute(args []string) error {
	fs := flag.NewFlagSet("runner", flag.ExitOnError)
	mpath := fs.String("model", "", "Path to model binary file")
	parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
	batchSize := fs.Int("batch-size", 512, "Batch size")
	numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
	mainGPU := fs.Int("main-gpu", 0, "Main GPU")
	flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
	kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
	kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
	port := fs.Int("port", 8080, "Port to expose the server on")
	threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
	_ = fs.Bool("verbose", false, "verbose output (default: disabled)")
	_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
	tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
	multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")

	var lpaths multiLPath
	fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")

	fs.Usage = func() {
		fmt.Fprintf(fs.Output(), "Runner usage\n")
		fs.PrintDefaults()
	}
	if err := fs.Parse(args); err != nil {
		return err
846
	}
847
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
Jesse Gross's avatar
849
850
851

	server := &Server{
		batchSize: *batchSize,
852
		status:    llm.ServerStatusLoadingModel,
853
854
	}

Jesse Gross's avatar
	//	no-mmap
	//	mlock

859
	var tensorSplitFloats []float32
860
	if *tensorSplit != "" {
861
862
863
		splits := strings.Split(*tensorSplit, ",")
		tensorSplitFloats = make([]float32, len(splits))
		for i, s := range splits {
864
			f, _ := strconv.ParseFloat(s, 32)
865
			tensorSplitFloats[i] = float32(f)
866
		}
867
868
869
	}

	params := ml.BackendParams{
870
871
872
		Progress: func(progress float32) {
			server.progress = progress
		},
873
874
875
876
877
		NumThreads:     *threads,
		NumGPULayers:   *numGPULayers,
		MainGPU:        *mainGPU,
		TensorSplit:    tensorSplitFloats,
		FlashAttention: *flashAttention,
878
	}
879
880
881

	server.ready.Add(1)
	ctx, cancel := context.WithCancel(context.Background())
Michael Yang's avatar

884
885
886
887
	go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)

	server.cond = sync.NewCond(&server.mu)

888
889
890
891
892
893
	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
894
		return err
895
896
897
898
	}
	defer listener.Close()

	mux := http.NewServeMux()
899
900
901
902
903
904
905
	// TODO: support embeddings
	mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
	})

	mux.HandleFunc("POST /completion", server.completion)
	mux.HandleFunc("GET /health", server.health)
906
907
908
909
910
911
912
913

	httpServer := http.Server{
		Handler: mux,
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
914
		return err
915
916
	}

917
	return nil
918
}