package ollamarunner

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"hash/maphash"
	"log"
	"log/slog"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	"github.com/ollama/ollama/runner/common"
	"github.com/ollama/ollama/sample"

	_ "github.com/ollama/ollama/model/models"
)

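// contextList bundles the ml.Contexts allocated while building a sequence's
// inputs (such as multimodal embeddings) so they can all be closed together
// once the owning Sequence is no longer reachable.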
type contextList struct {
	list []ml.Context
}

type Sequence struct {
	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
	// multimodal embeddings
	ctxs *contextList

	// batch index
	iBatch int

	// prompt inputs left to evaluate
	inputs []input.Input

	// inputs that have been added to a batch but not yet submitted to Forward
	pendingInputs []input.Input

	// tokens that have been generated but not returned yet (e.g. for stop sequences)
	pendingResponses []string

	// input cache being used by this sequence
	cache *InputCacheSlot

	// channel to send responses over
	responses chan string

	// channel to stop decoding (such as if the remote connection is closed)
	quit chan bool

	// number of tokens to predict
	numPredict int

	// sampler with transforms to run on generated logits
	sampler sample.Sampler

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// number of inputs to keep at the beginning when shifting context window
	numKeep int32

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool

	doneReason llm.DoneReason

	// Metrics
	startProcessingTime time.Time
	startGenerationTime time.Time
	numPredicted        int
	numPromptInputs     int
}

type NewSequenceParams struct {
	numPredict int
	stop       []string
	numKeep    int32
	sampler    sample.Sampler
	embedding  bool
}

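// NewSequence tokenizes the prompt, encodes any images, truncates the result
// to fit the context window and returns a Sequence ready to be scheduled.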
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
	s.ready.Wait()

	startTime := time.Now()

	inputs, ctxs, err := s.inputs(prompt, images)
	if err != nil {
		return nil, fmt.Errorf("failed to process inputs: %w", err)
	} else if len(inputs) == 0 {
		return nil, errors.New("no input provided")
	}

	if params.numKeep < 0 {
		params.numKeep = int32(len(inputs))
	}

	// Ensure that at least 1 input can be discarded during shift
	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

	if int32(len(inputs)) > s.cache.numCtx {
		discard := int32(len(inputs)) - s.cache.numCtx
		promptStart := params.numKeep + discard

		// If we need to truncate in the middle of an unbreakable batch, remove the entire batch
		sameBatch := 0
		for i, inp := range inputs {
			if sameBatch > 0 {
				sameBatch--

				if promptStart == int32(i) {
					promptStart++
				}
			} else if promptStart == int32(i) {
				break
			}

			if inp.SameBatch != 0 {
				if int32(i) < params.numKeep {
					return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
				}

				sameBatch = inp.SameBatch
			}
		}

		if promptStart >= int32(len(inputs)) {
			return nil, errors.New("entire prompt removed by truncation")
		}

		newInputs := inputs[:params.numKeep]
		newInputs = append(newInputs, inputs[promptStart:]...)

		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
		inputs = newInputs
	}

	// TODO(jessegross): Ingest cached history for grammar

	return &Sequence{
		ctxs:                ctxs,
		inputs:              inputs,
		numPromptInputs:     len(inputs),
		startProcessingTime: startTime,
		numPredict:          params.numPredict,
		pendingResponses:    make([]string, 0),
		responses:           make(chan string, 100),
		quit:                make(chan bool, 1),
		embedding:           make(chan []float32, 1),
		sampler:             params.sampler,
		embeddingOnly:       params.embedding,
		stop:                params.stop,
		numKeep:             params.numKeep,
	}, nil
}

// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
	var inputs []input.Input
	var parts []string
	var matches [][]string

	multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)

	if visionModel {
		re := regexp.MustCompile(`\[img-(\d+)\]`)
		parts = re.Split(prompt, -1)
		matches = re.FindAllStringSubmatch(prompt, -1)
	} else {
		parts = []string{prompt}
	}

	var contexts contextList
	runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
		for _, ctx := range ctxs {
			ctx.Close()
		}
	}, contexts.list)

	postTokenize := false
	for i, part := range parts {
		// text - tokenize
		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
		if err != nil {
			return nil, nil, err
		}

		for _, t := range tokens {
			inputs = append(inputs, input.Input{Token: t})
		}

		// image - decode and store
		if i < len(matches) {
			n, _ := strconv.Atoi(matches[i][1])

			imageIndex := -1
			for j := range images {
				if images[j].ID == n {
					imageIndex = j
					break
				}
			}

			if imageIndex < 0 {
				return nil, nil, fmt.Errorf("invalid image index: %d", n)
			}

			ctx := s.model.Backend().NewContext()
			contexts.list = append(contexts.list, ctx)
			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
			if err != nil {
				return nil, nil, err
			}

			s.multimodalHash.Reset()
			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
			imageHash := s.multimodalHash.Sum64()

			inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
			postTokenize = true
		}
	}

	if visionModel && postTokenize {
		var err error
		inputs, err = multimodalProcessor.PostTokenize(inputs)
		if err != nil {
			return nil, nil, err
		}
	}

	return inputs, &contexts, nil
}

type Server struct {
	// is the server ready to process requests?
	// protects access to model and image
	ready sync.WaitGroup

	// loaded model
	model model.Model

	// status for external health reporting - loading, ready to serve, etc.
	status llm.ServerStatus

	// current progress on loading the model
	progress float32

	// number of simultaneous requests to handle
	parallel int

	// maximum number of elements in a batch (per sequence)
	// TODO (jmorganca): make this n_batch
	batchSize int

	// protects access to everything below this line
	// this is context state needed for decoding
	mu sync.Mutex

	// indicates that data is ready for processing
	cond *sync.Cond

	// the list of simultaneous sequences being evaluated
	seqs []*Sequence

	// seqs can have a maximum of parallel entries, which
	// is enforced by seqsSem
	seqsSem *semaphore.Weighted

	// KV cache
	cache *InputCache

	// next sequence for prompt processing to avoid starvation
	nextSeq int

	// multimodalHash generates hashes for comparing equality
	// of non-text data
	multimodalHash maphash.Hash

	// vocab is a llama.cpp vocab required for grammar-based
	// constrained generation (json mode, structured outputs)
	// TODO: this is temporary until Ollama sampling supports
	// constrained generation
	vocab *sample.Vocab
}

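// allNil reports whether no sequence slot is currently occupied.
// Callers must hold s.mu.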
func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}
	return true
}

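// flushPending joins any buffered response pieces, trims trailing bytes that
// would form invalid UTF-8 and sends the result on the sequence's response
// channel. It returns false if the sequence has been told to quit.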
func flushPending(seq *Sequence) bool {
	joined := strings.Join(seq.pendingResponses, "")
	seq.pendingResponses = []string{}

	// Check if there are any partial UTF-8 characters remaining.
	// We already check and queue as we are generating but some may
	// still make it here:
	// - Sequence is ending, e.g. generation limit has been hit
	// - Invalid characters in the middle of a string
	// This is a stricter check to ensure we never output invalid Unicode.
	for !utf8.ValidString(joined) {
		joined = joined[:len(joined)-1]
	}

	if len(joined) == 0 {
		return true
	}

	select {
	case seq.responses <- joined:
		return true
	case <-seq.quit:
		return false
	}
}

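// removeSequence finishes a sequence: it flushes any remaining output,
// records the done reason, closes its channels, marks its cache slot as free
// and releases its slot in seqsSem. Callers must hold s.mu.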
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
	seq := s.seqs[seqIndex]

	flushPending(seq)
	seq.doneReason = reason
	close(seq.responses)
	close(seq.embedding)
	seq.cache.InUse = false
	s.seqs[seqIndex] = nil
	s.seqsSem.Release(1)
}

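// run drives the decode loop, processing one batch at a time until ctx is cancelled.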
func (s *Server) run(ctx context.Context) {
	s.ready.Wait()

	for {
		select {
		case <-ctx.Done():
			return
		default:
			err := s.processBatch()
			if err != nil {
				panic(err)
			}
		}
	}
}

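// processBatch gathers pending inputs from the active sequences into a single
// forward pass, then samples one token for each sequence that has finished its
// prompt, handling stop sequences, cache bookkeeping and sequence removal.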
func (s *Server) processBatch() error {
	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	var batchInputs []int32
	var batch input.Batch

	resumeSeq := -1
	seqIdx := s.nextSeq - 1
	for range s.seqs {
		seqIdx = (seqIdx + 1) % len(s.seqs)
		seq := s.seqs[seqIdx]

		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
			s.removeSequence(seqIdx, llm.DoneReasonLength)
			continue
		}

		if !s.cache.enabled {
			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
			seq.cache.Inputs = []input.Input{}
		}

		batchSize := s.batchSize

		for i, inp := range seq.inputs {
			// If we are required to put the following inputs into a single batch then extend
			// the batch size. Since we only extend the size by the minimum amount necessary,
			// this will still cause a break if the batch already contains existing inputs.
			minBatch := 1 + inp.SameBatch
			if minBatch > batchSize {
				batchSize = minBatch
			}

			// Stop if the required batch would put us over the total batch size (including tokens
			// added by other sequences). If we haven't been able to add anything yet then pick up
			// here again for the next batch to avoid starvation, though we can opportunistically
			// check if other sequences can still squeeze something in.
			if len(batchInputs)+minBatch > batchSize {
				if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
					resumeSeq = seqIdx
				}
				break
			}

			// If the sum of our working set (already processed tokens, tokens we added to this
			// batch, required following tokens) exceeds the context size, then trigger a shift
			// now so we don't have to do one later when we can't break the batch.
			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
				if len(seq.pendingInputs) != 0 {
					break
				}

				err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
				if err != nil {
					var reprocess *ErrReprocessInputs
					if errors.As(err, &reprocess) {
						// Prepend these inputs to the sequence's inputs queue for reprocessing
						seq.inputs = append(reprocess.Inputs, seq.inputs...)
						// Skip this sequence but continue processing the rest
						continue
					} else {
						return err
					}
				}
			}

			batchInputs = append(batchInputs, inp.Token)
			if inp.Multimodal != nil {
				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
			}

			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
			batch.Sequences = append(batch.Sequences, seq.cache.Id)

			seq.iBatch = len(batch.Outputs)
			if i+1 == len(seq.inputs) {
				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
			}
			seq.pendingInputs = append(seq.pendingInputs, inp)
		}

		seq.inputs = seq.inputs[len(seq.pendingInputs):]
	}

	if resumeSeq != -1 {
		s.nextSeq = resumeSeq
	} else {
		s.nextSeq = seqIdx + 1
	}

	if len(batchInputs) == 0 {
		return nil
	}

	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
	if err != nil {
		return fmt.Errorf("failed to decode batch: %w", err)
	}

	logits := modelOutput.Floats()

	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}

		// After calling Forward, pending inputs are now in the cache
		if len(seq.pendingInputs) > 0 {
			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
			seq.pendingInputs = []input.Input{}
		}

		// don't sample prompt processing
		if len(seq.inputs) != 0 {
			if !s.cache.enabled {
				return errors.New("caching disabled but unable to fit entire input in a batch")
			}
			continue
		}

		seq.numPredicted++
		if seq.numPredicted == 1 {
			seq.startGenerationTime = time.Now()
		}

		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			// TODO(jessegross): Embedding support
			slog.Warn("generation of embedding outputs not yet supported")
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// sample a token
		vocabSize := len(logits) / len(batch.Outputs)

		token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
		if err != nil {
			return fmt.Errorf("failed to sample token: %w", err)
		}

		// if it's an end of sequence token, stop generating for this sequence
		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
		if err != nil {
			return err
		}

		seq.inputs = []input.Input{{Token: token}}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")

		if ok, stop := common.FindStop(sequence, seq.stop); ok {
			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

			var tokenTruncated bool
			origLen := len(seq.pendingResponses)
			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
			newLen := len(seq.pendingResponses)

			// Update the cache based on the tokens that will be returned:
			// - We have 1 token more than is currently in the cache because
			// the last one generated wasn't submitted to Decode
			// - Remove any stop sequences that we stripped out
			// - If TruncateStop removed a portion of a token, drop that
			// - As defense-in-depth, if TruncateStop didn't find a stop token,
			// remove the extra one that we added to the cache len
			tokenLen := len(seq.cache.Inputs) + 1
			tokenLen -= origLen - newLen
			if tokenTruncated || origLen == newLen {
				tokenLen--
			}
			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		if common.ContainsStopSuffix(sequence, seq.stop) {
			continue
		}

		if common.IncompleteUnicode(sequence) {
			continue
		}

		if !flushPending(seq) {
			s.removeSequence(i, llm.DoneReasonConnectionClosed)
		}
	}

	return nil
}

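// completion handles POST /completion: it builds a Sequence from the request,
// waits for a free slot, then streams generated chunks as JSON until the
// sequence finishes or the client disconnects.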
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req llm.CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	var grammar *sample.Grammar
	var err error
	if req.Grammar != "" {
		grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
		if err != nil {
			http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
			return
		}
	}

	sampler := sample.NewSampler(
		req.Options.Temperature,
		req.Options.TopK,
		req.Options.TopP,
		req.Options.MinP,
		req.Options.Seed,
		grammar,
	)

	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
		numPredict: req.Options.NumPredict,
		stop:       req.Options.Stop,
		numKeep:    int32(req.Options.NumKeep),
		sampler:    sampler,
		embedding:  false,
	})
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	for {
		select {
		case <-r.Context().Done():
			close(seq.quit)
			return
		case content, ok := <-seq.responses:
			if ok {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content: content,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					close(seq.quit)
					return
				}

				flusher.Flush()
			} else {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Done:               true,
					DoneReason:         seq.doneReason,
					PromptEvalCount:    seq.numPromptInputs,
					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
					EvalCount:          seq.numPredicted,
					EvalDuration:       time.Since(seq.startGenerationTime),
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
				}

				return
			}
		}
	}
}

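// health reports the server status and model load progress as JSON.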
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
		Status:   s.status,
		Progress: s.progress,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

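// multiLPath implements flag.Value so that -lora can be passed multiple times.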
type multiLPath []string

func (m *multiLPath) Set(value string) error {
	*m = append(*m, value)
	return nil
}

func (m *multiLPath) String() string {
	return strings.Join(*m, ", ")
}

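// reserveWorstCaseGraph builds a dummy batch of the maximum configured size
// and reserves (without computing) its forward graph so that worst-case
// backend memory is allocated up front.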
func (s *Server) reserveWorstCaseGraph() error {
	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

	var batch input.Batch

	inputs := make([]int32, s.batchSize)
	batch.Positions = make([]int32, len(inputs))
	batch.Sequences = make([]int, len(inputs))
	for i := range inputs {
		batch.Positions[i] = int32(i)
	}

	batch.Outputs = make([]int32, s.parallel)
	for i := range batch.Outputs {
		batch.Outputs[i] = int32(i)
	}

	var err error
	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
	if err != nil {
		return err
	}

	cache := s.model.Config().Cache
	if cache != nil {
		err := cache.StartForward(ctx, batch, true)
		if err != nil {
			return err
		}
	}

	t, err := s.model.Forward(ctx, batch)
	if err != nil {
		return err
	}

	err = ctx.Forward(t).Reserve()
	if err != nil {
		return err
	}

	return nil
}

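// loadModel loads the model and input cache in the background, configures
// parallelism and reserves the worst-case graph before marking the server ready.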
func (s *Server) loadModel(
	ctx context.Context,
	mpath string,
	params ml.BackendParams,
	lpath multiLPath,
	parallel int,
	kvCacheType string,
	kvSize int,
	multiUserCache bool,
) {
	var err error
	s.model, err = model.New(ctx, mpath, params)
	if err != nil {
		panic(err)
	}

	s.vocab = sample.NewVocab(mpath)

	// TODO(jessegross): LoRA loading
	if lpath.String() != "" {
		panic("loras are not yet implemented")
	}

	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
	if err != nil {
		panic(err)
	}

	if !s.cache.enabled && parallel > 1 {
		parallel = 1
		slog.Warn("model does not support caching, disabling parallel processing")
	}

	s.parallel = parallel
	s.seqs = make([]*Sequence, s.parallel)
	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

	err = s.reserveWorstCaseGraph()
	if err != nil {
		panic(err)
	}

	s.status = llm.ServerStatusReady
	s.ready.Done()
}

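// Execute parses the runner flags, loads the model in the background and
// serves the completion and health endpoints until the listener is closed.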
func Execute(args []string) error {
	fs := flag.NewFlagSet("runner", flag.ExitOnError)
	mpath := fs.String("model", "", "Path to model binary file")
	parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
	batchSize := fs.Int("batch-size", 512, "Batch size")
	numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
	mainGPU := fs.Int("main-gpu", 0, "Main GPU")
	flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
	kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
	kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
	port := fs.Int("port", 8080, "Port to expose the server on")
	threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
	verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
	_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
	_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
	tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
	multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")

	var lpaths multiLPath
	fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")

	fs.Usage = func() {
		fmt.Fprintf(fs.Output(), "Runner usage\n")
		fs.PrintDefaults()
	}
	if err := fs.Parse(args); err != nil {
		return err
	}
	level := slog.LevelInfo
	if *verbose {
		level = slog.LevelDebug
	}
	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level:     level,
		AddSource: true,
		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
			if attr.Key == slog.SourceKey {
				source := attr.Value.Any().(*slog.Source)
				source.File = filepath.Base(source.File)
			}
			return attr
		},
	})
	slog.SetDefault(slog.New(handler))
	slog.Info("starting ollama engine")

	server := &Server{
		batchSize: *batchSize,
		status:    llm.ServerStatusLoadingModel,
	}

	// TODO(jessegross): Parameters that need to be implemented:
	//	no-mmap
	//	mlock

	var tensorSplitFloats []float32
	if *tensorSplit != "" {
		splits := strings.Split(*tensorSplit, ",")
		tensorSplitFloats = make([]float32, len(splits))
		for i, s := range splits {
			f, _ := strconv.ParseFloat(s, 32)
			tensorSplitFloats[i] = float32(f)
		}
	}

	params := ml.BackendParams{
		Progress: func(progress float32) {
			server.progress = progress
		},
		NumThreads:     *threads,
		NumGPULayers:   *numGPULayers,
		MainGPU:        *mainGPU,
		TensorSplit:    tensorSplitFloats,
		FlashAttention: *flashAttention,
	}

	server.ready.Add(1)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)

	server.cond = sync.NewCond(&server.mu)

	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return err
	}
	defer listener.Close()

	mux := http.NewServeMux()
	// TODO: support embeddings
	mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
	})

	mux.HandleFunc("POST /completion", server.completion)
	mux.HandleFunc("GET /health", server.health)

	httpServer := http.Server{
		Handler: mux,
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
		return err
	}

	return nil
}