package ollamarunner

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"hash/maphash"
	"log"
	"log/slog"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	"github.com/ollama/ollama/runner/common"
	"github.com/ollama/ollama/sample"

	_ "github.com/ollama/ollama/model/models"
)

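// contextList owns the ml.Contexts allocated for a sequence (for example, for
// multimodal embeddings) so that they can all be closed together once the
// sequence is garbage collected.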
type contextList struct {
	list []ml.Context
}

type Sequence struct {
	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
	// multimodal embeddings
	ctxs *contextList

	// batch index
	iBatch int

	// prompt inputs left to evaluate
	inputs []input.Input

	// inputs that have been added to a batch but not yet submitted to Forward
	pendingInputs []input.Input

	// tokens that have been generated but not returned yet (e.g. for stop sequences)
	pendingResponses []string

	// input cache being used by this sequence
	cache *InputCacheSlot

	// channel to send responses over
	responses chan string

	// channel to stop decoding (such as if the remote connection is closed)
	quit chan bool

	// number of tokens to predict
	numPredict int

	// sampler with transforms to run on generated logits
	sampler sample.Sampler

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// number of inputs to keep at the beginning when shifting context window
	numKeep int32

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool

	doneReason llm.DoneReason

	// Metrics
	startProcessingTime time.Time
	startGenerationTime time.Time
	numPredicted        int
	numPromptInputs     int
}

type NewSequenceParams struct {
	numPredict int
	stop       []string
	numKeep    int32
	sampler    sample.Sampler
	embedding  bool
}

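// NewSequence converts a prompt and any attached images into model inputs,
// truncating them to fit the context window if necessary, and returns the
// initialized per-request state.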
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
	s.ready.Wait()

	startTime := time.Now()

	inputs, ctxs, err := s.inputs(prompt, images)
	if err != nil {
		return nil, fmt.Errorf("failed to process inputs: %w", err)
	} else if len(inputs) == 0 {
		return nil, errors.New("no input provided")
	}

	if params.numKeep < 0 {
		params.numKeep = int32(len(inputs))
	}

	// Ensure that at least 1 input can be discarded during shift
	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

	if int32(len(inputs)) > s.cache.numCtx {
		discard := int32(len(inputs)) - s.cache.numCtx
		promptStart := params.numKeep + discard

		// If we need to truncate in the middle of an unbreakable batch, remove the entire batch
		sameBatch := 0
		for i, inp := range inputs {
			if sameBatch > 0 {
				sameBatch--

				if promptStart == int32(i) {
					promptStart++
				}
			} else if promptStart == int32(i) {
				break
			}

			if inp.SameBatch != 0 {
				if int32(i) < params.numKeep {
					return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
				}

				sameBatch = inp.SameBatch
			}
		}

		if promptStart >= int32(len(inputs)) {
			return nil, errors.New("entire prompt removed by truncation")
		}

		newInputs := inputs[:params.numKeep]
		newInputs = append(newInputs, inputs[promptStart:]...)

		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
		inputs = newInputs
	}

	// TODO(jessegross): Ingest cached history for grammar

	return &Sequence{
		ctxs:                ctxs,
		inputs:              inputs,
		numPromptInputs:     len(inputs),
		startProcessingTime: startTime,
		numPredict:          params.numPredict,
		pendingResponses:    make([]string, 0),
		responses:           make(chan string, 100),
		quit:                make(chan bool, 1),
		embedding:           make(chan []float32, 1),
		sampler:             params.sampler,
		embeddingOnly:       params.embedding,
		stop:                params.stop,
		numKeep:             params.numKeep,
	}, nil
}

// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
	var inputs []input.Input
	var parts []string
	var matches [][]string

	multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)

	if visionModel {
		re := regexp.MustCompile(`\[img-(\d+)\]`)
		parts = re.Split(prompt, -1)
		matches = re.FindAllStringSubmatch(prompt, -1)
	} else {
		parts = []string{prompt}
	}

	var contexts contextList
	runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
		for _, ctx := range ctxs {
			ctx.Close()
		}
	}, contexts.list)

	postTokenize := false
	for i, part := range parts {
		// text - tokenize
		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
		if err != nil {
			return nil, nil, err
		}

		for _, t := range tokens {
			inputs = append(inputs, input.Input{Token: t})
		}

		// image - decode and store
		if i < len(matches) {
			n, _ := strconv.Atoi(matches[i][1])

			imageIndex := -1
			for j := range images {
				if images[j].ID == n {
					imageIndex = j
					break
				}
			}

			if imageIndex < 0 {
				return nil, nil, fmt.Errorf("invalid image index: %d", n)
			}

			ctx := s.model.Backend().NewContext()
			contexts.list = append(contexts.list, ctx)
			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
			if err != nil {
				return nil, nil, err
			}

			s.multimodalHash.Reset()
			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
			imageHash := s.multimodalHash.Sum64()

			inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
			postTokenize = true
		}
	}

	if visionModel && postTokenize {
		var err error
		inputs, err = multimodalProcessor.PostTokenize(inputs)
		if err != nil {
			return nil, nil, err
		}
	}

	return inputs, &contexts, nil
}

type Server struct {
	// is the server ready to process requests?
	// protects access to model and image
	ready sync.WaitGroup

	// loaded model
	model model.Model

	// status for external health reporting - loading, ready to serve, etc.
	status llm.ServerStatus

	// current progress on loading the model
	progress float32

	// number of simultaneous requests to handle
	parallel int

	// maximum number of elements in a batch (per sequence)
	// TODO (jmorganca): make this n_batch
	batchSize int

	// protects access to everything below this line
	// this is context state needed for decoding
	mu sync.Mutex

	// indicates that data is ready for processing
	cond *sync.Cond

	// the list of simultaneous sequences being evaluated
	seqs []*Sequence

	// seqs can have a maximum of parallel entries, which
	// is enforced by seqsSem
	seqsSem *semaphore.Weighted

	// KV cache
	cache *InputCache

	// next sequence for prompt processing to avoid starvation
	nextSeq int

	// multimodalHash generates hashes for comparing equality
	// of non-text data
	multimodalHash maphash.Hash
}

func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}
	return true
}

func flushPending(seq *Sequence) bool {
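	// Join whatever has been buffered and send it to the client as a single
	// chunk; report false if the request has been aborted (quit was closed).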
	joined := strings.Join(seq.pendingResponses, "")
	seq.pendingResponses = []string{}

	// Check if there are any partial UTF-8 characters remaining.
	// We already check and queue as we are generating but some may
	// still make it here:
	// - Sequence is ending, e.g. generation limit has been hit
	// - Invalid characters in the middle of a string
	// This is a stricter check to ensure we never output invalid Unicode.
	for !utf8.ValidString(joined) {
		joined = joined[:len(joined)-1]
	}

	if len(joined) == 0 {
		return true
	}

	select {
	case seq.responses <- joined:
		return true
	case <-seq.quit:
		return false
	}
}

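// removeSequence finishes the sequence in slot seqIndex: it flushes any
// pending output, records the done reason, closes the response channels,
// releases the cache slot, and frees a semaphore slot for the next request.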
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
	seq := s.seqs[seqIndex]

	flushPending(seq)
	seq.doneReason = reason
	close(seq.responses)
	close(seq.embedding)
	seq.cache.InUse = false
	s.seqs[seqIndex] = nil
	s.seqsSem.Release(1)
}

func (s *Server) run(ctx context.Context) {
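	// Wait until the model has finished loading, then keep building and
	// decoding batches until the context is cancelled.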
	s.ready.Wait()

	for {
		select {
		case <-ctx.Done():
			return
		default:
			err := s.processBatch()
			if err != nil {
				panic(err)
			}
		}
	}
}

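// processBatch collects pending inputs from the active sequences into a
// single batch, runs the model forward pass, and samples one token for each
// sequence that has finished processing its prompt.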
func (s *Server) processBatch() error {
	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	var batchInputs []int32
	var batch input.Batch

	resumeSeq := -1
	seqIdx := s.nextSeq - 1
	for range s.seqs {
		seqIdx = (seqIdx + 1) % len(s.seqs)
		seq := s.seqs[seqIdx]

		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
			s.removeSequence(seqIdx, llm.DoneReasonLength)
			continue
		}

		if !s.cache.enabled {
			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
			seq.cache.Inputs = []input.Input{}
		}

		batchSize := s.batchSize

		for i, inp := range seq.inputs {
			// If we are required to put following inputs into a single batch then extend the
			// batch size. Since we are only extending the size the minimum amount possible, this
			// will cause a break if we have existing inputs.
			minBatch := 1 + inp.SameBatch
			if minBatch > batchSize {
				batchSize = minBatch
			}

			// Stop if the required batch would put us over the total batch size (including tokens
			// added by other sequences). If we haven't been able to add anything yet then pick up
			// here again for the next batch to avoid starvation, though we can opportunistically
			// check if other sequences can still squeeze something in.
			if len(batchInputs)+minBatch > batchSize {
				if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
					resumeSeq = seqIdx
				}
				break
			}

			// If the sum of our working set (already processed tokens, tokens we added to this
			// batch, required following tokens) exceeds the context size, then trigger a shift
			// now so we don't have to do one later when we can't break the batch.
			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
				if len(seq.pendingInputs) != 0 {
					break
				}

				err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
				if err != nil {
					var reprocess *ErrReprocessInputs
					if errors.As(err, &reprocess) {
						// Prepend these inputs to the sequence's inputs queue for reprocessing
						seq.inputs = append(reprocess.Inputs, seq.inputs...)
						// Skip this sequence but continue processing the rest
						continue
					} else {
						return err
					}
				}
			}

			batchInputs = append(batchInputs, inp.Token)
			if inp.Multimodal != nil {
				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
			}

			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
			batch.Sequences = append(batch.Sequences, seq.cache.Id)

			seq.iBatch = len(batch.Outputs)
			if i+1 == len(seq.inputs) {
				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
			}
			seq.pendingInputs = append(seq.pendingInputs, inp)
		}

		seq.inputs = seq.inputs[len(seq.pendingInputs):]
	}

	if resumeSeq != -1 {
		s.nextSeq = resumeSeq
	} else {
		s.nextSeq = seqIdx + 1
	}

	if len(batchInputs) == 0 {
		return nil
	}

	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
	if err != nil {
		return fmt.Errorf("failed to decode batch: %w", err)
	}

	logits := modelOutput.Floats()

	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}

		// After calling Forward, pending inputs are now in the cache
		if len(seq.pendingInputs) > 0 {
			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
			seq.pendingInputs = []input.Input{}
		}

		// don't sample prompt processing
		if len(seq.inputs) != 0 {
			if !s.cache.enabled {
				return errors.New("caching disabled but unable to fit entire input in a batch")
			}
			continue
		}

		seq.numPredicted++
		if seq.numPredicted == 1 {
			seq.startGenerationTime = time.Now()
		}

		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			// TODO(jessegross): Embedding support
			slog.Warn("generation of embedding outputs not yet supported")
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// sample a token
		vocabSize := len(logits) / len(batch.Outputs)

		token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
		if err != nil {
			return fmt.Errorf("failed to sample token: %w", err)
		}

		// if it's an end of sequence token, break
		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
		if err != nil {
			return err
		}

		seq.inputs = []input.Input{{Token: token}}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")

		if ok, stop := common.FindStop(sequence, seq.stop); ok {
			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

			var tokenTruncated bool
			origLen := len(seq.pendingResponses)
			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
			newLen := len(seq.pendingResponses)

			// Update the cache based on the tokens that will be returned:
			// - We have 1 token more than is currently in the cache because
			// the last one generated wasn't submitted to Decode
			// - Remove any stop sequences that we stripped out
			// - If truncateStop removed a portion of a token, drop that
			// - As defense-in-depth, if TruncateStop didn't find a stop token
			// remove the extra one that we added to the cache len
			tokenLen := len(seq.cache.Inputs) + 1
			tokenLen -= origLen - newLen
			if tokenTruncated || origLen == newLen {
				tokenLen--
			}
			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		if common.ContainsStopSuffix(sequence, seq.stop) {
			continue
		}

		if common.IncompleteUnicode(sequence) {
			continue
		}

		if !flushPending(seq) {
			s.removeSequence(i, llm.DoneReasonConnectionClosed)
		}
	}

	return nil
}

func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
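	// Decode the request, build a sampler, claim a sequence slot, and stream
	// chunked JSON responses until generation finishes or the client leaves.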
	var req llm.CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	var grammar *sample.GrammarSampler
	var err error
	if req.Grammar != "" {
		grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
		if err != nil {
			http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
			return
		}
		defer grammar.Free()
	}

	sampler := sample.NewSampler(
		req.Options.Temperature,
		req.Options.TopK,
		req.Options.TopP,
		req.Options.MinP,
		req.Options.Seed,
		grammar,
	)

	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
		numPredict: req.Options.NumPredict,
		stop:       req.Options.Stop,
		numKeep:    int32(req.Options.NumKeep),
		sampler:    sampler,
		embedding:  false,
	})
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	for {
		select {
		case <-r.Context().Done():
			close(seq.quit)
			return
		case content, ok := <-seq.responses:
			if ok {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content: content,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					close(seq.quit)
					return
				}

				flusher.Flush()
			} else {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Done:               true,
					DoneReason:         seq.doneReason,
					PromptEvalCount:    seq.numPromptInputs,
					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
					EvalCount:          seq.numPredicted,
					EvalDuration:       time.Since(seq.startGenerationTime),
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
				}

				return
			}
		}
	}
}

func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
		Status:   s.status,
		Progress: s.progress,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

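// multiLPath accumulates repeated -lora flag values into a single list.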
type multiLPath []string

func (m *multiLPath) Set(value string) error {
	*m = append(*m, value)
	return nil
}

func (m *multiLPath) String() string {
	return strings.Join(*m, ", ")
}

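// reserveWorstCaseGraph builds a dummy batch of the maximum configured size
// and reserves the backend memory for its compute graph before the server
// begins accepting requests.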
func (s *Server) reserveWorstCaseGraph() error {
	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

	var batch input.Batch

	inputs := make([]int32, s.batchSize)
	batch.Positions = make([]int32, len(inputs))
	batch.Sequences = make([]int, len(inputs))
	for i := range inputs {
		batch.Positions[i] = int32(i)
	}

	batch.Outputs = make([]int32, s.parallel)
	for i := range batch.Outputs {
		batch.Outputs[i] = int32(i)
	}

	var err error
	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
	if err != nil {
		return err
	}

	cache := s.model.Config().Cache
	if cache != nil {
		err := cache.StartForward(ctx, batch, true)
		if err != nil {
			return err
		}
	}

	t, err := s.model.Forward(ctx, batch)
	if err != nil {
		return err
	}

	err = ctx.Forward(t).Reserve()
	if err != nil {
		return err
	}

	return nil
}

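// loadModel initializes the model, KV cache, and sequence slots, reserves the
// worst-case compute graph, and marks the server as ready. It runs in its own
// goroutine and panics on unrecoverable errors.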
func (s *Server) loadModel(
	ctx context.Context,
	mpath string,
	params ml.BackendParams,
	lpath multiLPath,
	parallel int,
	kvCacheType string,
	kvSize int,
	multiUserCache bool,
) {
	var err error
	s.model, err = model.New(ctx, mpath, params)
	if err != nil {
		panic(err)
	}

	// TODO(jessegross): LoRA loading
	if lpath.String() != "" {
		panic("loras are not yet implemented")
	}

	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
	if err != nil {
		panic(err)
	}

	if !s.cache.enabled && parallel > 1 {
		parallel = 1
		slog.Warn("model does not support caching, disabling parallel processing")
	}

	s.parallel = parallel
	s.seqs = make([]*Sequence, s.parallel)
	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

	err = s.reserveWorstCaseGraph()
	if err != nil {
		panic(err)
	}

	s.status = llm.ServerStatusReady
	s.ready.Done()
}

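// Execute is the runner's entry point: it parses flags, configures logging,
// starts loading the model in the background, and serves the completion and
// health endpoints over HTTP.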
func Execute(args []string) error {
	fs := flag.NewFlagSet("runner", flag.ExitOnError)
	mpath := fs.String("model", "", "Path to model binary file")
	parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
	batchSize := fs.Int("batch-size", 512, "Batch size")
	numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
	mainGPU := fs.Int("main-gpu", 0, "Main GPU")
	flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
	kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
	kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
	port := fs.Int("port", 8080, "Port to expose the server on")
	threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
	verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
	_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
	_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
	tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
	multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")

	var lpaths multiLPath
	fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")

	fs.Usage = func() {
		fmt.Fprintf(fs.Output(), "Runner usage\n")
		fs.PrintDefaults()
	}
	if err := fs.Parse(args); err != nil {
		return err
	}
	level := slog.LevelInfo
	if *verbose {
		level = slog.LevelDebug
	}
	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level:     level,
		AddSource: true,
		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
			if attr.Key == slog.SourceKey {
				source := attr.Value.Any().(*slog.Source)
				source.File = filepath.Base(source.File)
			}
			return attr
		},
	})
	slog.SetDefault(slog.New(handler))
	slog.Info("starting ollama engine")

	server := &Server{
		batchSize: *batchSize,
		status:    llm.ServerStatusLoadingModel,
	}

	// TODO(jessegross): Parameters that need to be implemented:
	//	no-mmap
	//	mlock

	var tensorSplitFloats []float32
	if *tensorSplit != "" {
		splits := strings.Split(*tensorSplit, ",")
		tensorSplitFloats = make([]float32, len(splits))
		for i, s := range splits {
			f, _ := strconv.ParseFloat(s, 32)
			tensorSplitFloats[i] = float32(f)
		}
	}

	params := ml.BackendParams{
		Progress: func(progress float32) {
			server.progress = progress
		},
		NumThreads:     *threads,
		NumGPULayers:   *numGPULayers,
		MainGPU:        *mainGPU,
		TensorSplit:    tensorSplitFloats,
		FlashAttention: *flashAttention,
	}

	server.ready.Add(1)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)

	server.cond = sync.NewCond(&server.mu)

	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return err
	}
	defer listener.Close()

	mux := http.NewServeMux()
	// TODO: support embeddings
	mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
	})

	mux.HandleFunc("POST /completion", server.completion)
	mux.HandleFunc("GET /health", server.health)

	httpServer := http.Server{
		Handler: mux,
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
		return err
	}

	return nil
}