package llamarunner

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"log"
	"log/slog"
	"net"
	"net/http"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/runner/common"
)

// response contains a piece of generated text along with optional logprobs
type response struct {
	content  string
	logprobs []llm.Logprob
}

// input is an element of the prompt to process, either
// a token or an image embedding (generated from a vision projector)
type input struct {
	token int

	// embed is an image embedding
	embed []float32
}

type Sequence struct {
	// batch index
	iBatch int

	// number of tokens predicted so far
	numPredicted int

	// prompt inputs left to evaluate
	inputs []input

	// inputs that have been added to a batch but not yet submitted to Decode
	pendingInputs []input

	// tokens that have been generated but not returned yet (e.g. for stop sequences)
	pendingResponses []string

	// logprobs for tokens that haven't been returned yet
	pendingLogprobs []llm.Logprob

	// input cache being used by this sequence
	cache *InputCacheSlot

	// channel to send responses over
	responses chan response

	// channel to stop decoding (such as if the remote connection is closed)
	quit chan bool

	// number of tokens to predict
	numPredict int

	samplingCtx *llama.SamplingContext

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// number of inputs to keep at the beginning when shifting context window
	numKeep int

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool

	// shift if context window is exceeded
	shift bool

	doneReason llm.DoneReason

	// logprobs configuration
	logprobs    bool
	topLogprobs int

	// Metrics
	processingDuration time.Duration
	generationDuration time.Duration
	numDecoded         int
	numPromptInputs    int
}

type NewSequenceParams struct {
	numPredict     int
	stop           []string
	numKeep        int
	samplingParams *llama.SamplingParams
	embedding      bool
	shift          bool
	truncate       bool
	logprobs       bool
	topLogprobs    int
}

var errorInputTooLong = errors.New("the input length exceeds the context length")

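// NewSequence tokenizes the prompt (and any image references), applies context-window
// truncation, primes the sampling context with the prompt tokens, and returns a
// Sequence ready to be scheduled.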
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
	s.ready.Wait()

	inputs, err := s.inputs(prompt, images)
	if err != nil {
		return nil, fmt.Errorf("failed to process inputs: %w", err)
	} else if len(inputs) == 0 {
		return nil, errors.New("no input provided")
	}

	if params.numKeep < 0 {
		params.numKeep = len(inputs)
	}

	if s.model.AddBOSToken() {
		params.numKeep += 1
	}

	// Ensure that at least 1 input can be discarded during shift
	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

	if len(inputs) > s.cache.numCtx {
		discard := len(inputs) - s.cache.numCtx
		if !params.truncate {
			return nil, errorInputTooLong
		}

		newInputs := inputs[:params.numKeep]
		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)

		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
		inputs = newInputs
	}

	var sc *llama.SamplingContext
	if params.samplingParams != nil {
		sc, err = llama.NewSamplingContext(s.model, *params.samplingParams)
		if err != nil {
			return nil, err
		}
		for _, input := range inputs {
			if input.embed == nil {
				sc.Accept(input.token, false)
			}
		}
	}

	return &Sequence{
		inputs:           inputs,
		numPromptInputs:  len(inputs),
		numPredict:       params.numPredict,
		pendingResponses: make([]string, 0),
		responses:        make(chan response, 100),
		quit:             make(chan bool, 1),
		embedding:        make(chan []float32, 1),
		samplingCtx:      sc,
		embeddingOnly:    params.embedding,
		stop:             params.stop,
		numKeep:          params.numKeep,
		shift:            params.shift,
		logprobs:         params.logprobs,
		topLogprobs:      params.topLogprobs,
	}, nil
}

// calculateLogprobsLlama converts raw logits to log probabilities and finds top K tokens
func calculateLogprobsLlama(logits []float32, selectedToken int, topK int, model *llama.Model) []llm.Logprob {
	return common.CalculateLogprobs(logits, selectedToken, topK, model.TokenToPiece)
}

// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
	var inputs []input
	var parts []string
	var matches [][]string

	if s.image != nil {
		re := regexp.MustCompile(`\[img-(\d+)\]`)
		parts = re.Split(prompt, -1)
		matches = re.FindAllStringSubmatch(prompt, -1)
	} else {
		parts = []string{prompt}
	}

	for i, part := range parts {
		// text - tokenize
		tokens, err := s.lc.Model().Tokenize(part, i == 0, true)
		if err != nil {
			return nil, err
		}

		for _, t := range tokens {
			inputs = append(inputs, input{token: t})
		}

		// image - generate image embedding
		if i < len(matches) {
			n, _ := strconv.Atoi(matches[i][1])

			imageIndex := -1
			for j := range images {
				if images[j].ID == n {
					imageIndex = j
					break
				}
			}

			if imageIndex < 0 {
				return nil, fmt.Errorf("invalid image index: %d", n)
			}

			chunks, err := s.image.MultimodalTokenize(s.lc, images[imageIndex].Data)
			if err != nil {
				return nil, err
			}

			for _, c := range chunks {
				if len(c.Embed) != 0 {
					inputs = append(inputs, input{embed: c.Embed})
				} else {
					for _, t := range c.Tokens {
						inputs = append(inputs, input{token: t})
					}
				}
			}
		}
	}

	return inputs, nil
}

type Server struct {
	// modelPath is the location of the model to be loaded
	modelPath string

	// loadMu prevents more than one load attempt from occurring at a time
	loadMu sync.Mutex

	// is the server ready to process requests?
	// protects access to model and image
	ready sync.WaitGroup

	// loaded model
	model *llama.Model

	// image model context for multi-modal models
	image *ImageContext

	// status for external health reporting - loading, ready to serve, etc.
	status llm.ServerStatus

	// current progress on loading the model
	progress float32

	// number of simultaneous requests to handle
	parallel int

	// maximum number of elements in a batch (per sequence)
	// TODO (jmorganca): make this n_batch
	batchSize int

	// protects access to everything below this line
	// this is context state needed for decoding
	mu sync.Mutex

	// indicates that data is ready for processing
	cond *sync.Cond

	// decoding state
	lc *llama.Context

	// the list of simultaneous sequences being evaluated
	seqs []*Sequence

	// seqs can have a maximum of parallel entries, which
	// is enforced by seqsSem
	seqsSem *semaphore.Weighted

	// KV cache
	cache *InputCache

	// next sequence for prompt processing to avoid starvation
	nextSeq int
}

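// allNil reports whether every sequence slot is empty, i.e. there is no work to process.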
func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}
	return true
}

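// flushPending sends any buffered response text and logprobs for the sequence,
// trimming trailing bytes that would form invalid UTF-8. It returns false if the
// request was canceled before the data could be delivered.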
func flushPending(seq *Sequence) bool {
	joined := strings.Join(seq.pendingResponses, "")
	logprobs := seq.pendingLogprobs
	seq.pendingResponses = []string{}
	seq.pendingLogprobs = []llm.Logprob{}

	// Check if there are any partial UTF-8 characters remaining.
	// We already check and queue as we are generating but some may
	// still make it here:
	// - Sequence is ending, e.g. generation limit has been hit
	// - Invalid characters in the middle of a string
	// This is a stricter check to ensure we never output invalid Unicode.
	for !utf8.ValidString(joined) {
		joined = joined[:len(joined)-1]
	}

	if len(joined) == 0 {
		return true
	}

	select {
	case seq.responses <- response{content: joined, logprobs: logprobs}:
		return true
	case <-seq.quit:
		return false
	}
}

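// removeSequence finishes a sequence: it flushes pending output, records the done
// reason, closes the response and embedding channels, releases the cache slot, and
// frees the slot in s.seqs.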
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
	seq := s.seqs[seqIndex]

	flushPending(seq)
	seq.doneReason = reason
	close(seq.responses)
	close(seq.embedding)
	seq.cache.InUse = false
	s.seqs[seqIndex] = nil
	s.seqsSem.Release(1)
}

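// run is the main decode loop: it allocates reusable token and embedding batches and
// repeatedly calls processBatch until the context is canceled.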
func (s *Server) run(ctx context.Context) {
	s.ready.Wait()

	// Logically these batches are used only within the context of processBatch
	// but it is better for performance to allocate them once here
	tokenBatch, err := llama.NewBatch(s.batchSize, len(s.seqs), 0)
	if err != nil {
		panic(err)
	}
	defer tokenBatch.Free()

	var embedBatch *llama.Batch
	embedBatchSize := s.image.BatchSize(s.batchSize)
	if embedBatchSize != 0 {
		embedBatch, err = llama.NewBatch(embedBatchSize, len(s.seqs), s.image.EmbedSize(s.lc))
		if err != nil {
			panic(err)
		}
		defer embedBatch.Free()
	} else {
		embedBatch = &llama.Batch{}
	}

	for {
		select {
		case <-ctx.Done():
			return
		default:
			err := s.processBatch(tokenBatch, embedBatch)
			if err != nil {
				panic(err)
			}

			tokenBatch.Clear()
			embedBatch.Clear()
		}
	}
}

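// processBatch fills a batch from the active sequences, decodes it, and then samples
// tokens, checks stop conditions, and streams output for sequences that have finished
// their prompt.
//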
// TODO (jmorganca): processBatch should be simplified, removing:
// * sampling
// * stop token checking
// * metrics
// these should instead be handled by the handlers
// it should only be responsible for accepting tokens or embeddings and
// processing batches as fast as possible
func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) error {
	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	var batch *llama.Batch
	var numOutputs int

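	// Iterate over the sequences round-robin, starting where the previous pass left
	// off, so that prompt processing of later sequences is not starved.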
	seqIdx := s.nextSeq - 1
	for range s.seqs {
		seqIdx = (seqIdx + 1) % len(s.seqs)
		seq := s.seqs[seqIdx]

		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
			s.removeSequence(seqIdx, llm.DoneReasonLength)
			continue
		}

		for i, input := range seq.inputs {
			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
				if len(seq.pendingInputs) == 0 {
					if !seq.shift {
						s.removeSequence(seqIdx, llm.DoneReasonLength)
						break
					}

					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
					if err != nil {
						var reprocess *ErrReprocessInputs
						if errors.As(err, &reprocess) {
							// Prepend these inputs to the sequence's inputs queue for reprocessing
							seq.inputs = append(reprocess.Inputs, seq.inputs...)
							// Continue processing as normal
							continue
						} else {
							return err
						}
					}
				} else {
					break
				}
			}

			embedding := input.embed != nil

			// If we don't currently have a batch, use one of the correct type and
			// fill it up as much as possible across all sequences. If we encounter an
			// input of the opposite type, stop for that sequence but then pick up from
			// there for the next batch, ensuring that we alternate types
			if batch == nil {
				if !embedding {
					batch = tokenBatch
				} else {
					batch = embedBatch
				}
			} else if embedding != batch.IsEmbedding() {
				s.nextSeq = seqIdx
				break
			}

			if i >= batch.Size() {
				break
			}

			output := i+1 == len(seq.inputs)
			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id)
			if output {
				numOutputs++
			}

			seq.pendingInputs = append(seq.pendingInputs, input)
			seq.iBatch = batch.NumTokens() - 1
		}

		seq.inputs = seq.inputs[len(seq.pendingInputs):]
	}

	if batch == nil || batch.NumTokens() == 0 {
		return nil
	}

	t := time.Now()
	if err := s.lc.Decode(batch); err != nil {
		return fmt.Errorf("failed to decode batch: %w", err)
	}

	if numOutputs > 0 {
		s.lc.Synchronize()
	}

	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}

		// After calling Decode, pending inputs are now in the cache
		if len(seq.pendingInputs) > 0 {
			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
			seq.pendingInputs = []input{}
		}

		// don't sample prompt processing
		if len(seq.inputs) != 0 {
			seq.processingDuration += time.Since(t)
			continue
		}

		seq.numDecoded++
		if seq.numDecoded > 1 {
			seq.generationDuration += time.Since(t)
		} else {
			seq.processingDuration += time.Since(t)
		}

		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			embed := s.lc.GetEmbeddingsSeq(seq.cache.Id)
			if embed == nil {
				embed = s.lc.GetEmbeddingsIth(seq.iBatch)
			}

			seq.embedding <- embed
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// sample a token
		token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
		seq.samplingCtx.Accept(token, true)
		piece := s.model.TokenToPiece(token)

		seq.numPredicted++

		// if it's an end of sequence token, break
		if s.model.TokenIsEog(token) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
		if seq.logprobs {
			logits := s.lc.GetLogitsIth(seq.iBatch)
			if logits != nil {
				logprobs := calculateLogprobsLlama(logits, token, seq.topLogprobs, s.model)
				seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
			}
		}

		seq.inputs = []input{{token: token}}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")

		if ok, stop := common.FindStop(sequence, seq.stop); ok {
			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

			var tokenTruncated bool
			origLen := len(seq.pendingResponses)
			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
			newLen := len(seq.pendingResponses)

			// Truncate logprobs to match the truncated responses
			if seq.logprobs {
				origLogprobsLen := len(seq.pendingLogprobs)
				numTokensRemoved := origLen - newLen
				newLogprobsLen := origLogprobsLen - numTokensRemoved
				if newLogprobsLen < 0 {
					newLogprobsLen = 0
				}
				seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
			}

			// Update the cache based on the tokens that will be returned:
			// - We have 1 token more than is currently in the cache because
			// the last one generated wasn't submitted to Decode
			// - Remove any stop sequences that we stripped out
			// - If TruncateStop removed a portion of a token, drop that
			// - As defense-in-depth, if TruncateStop didn't find a stop token,
			// remove the extra one that we added to the cache len
			tokenLen := len(seq.cache.Inputs) + 1
			tokenLen -= origLen - newLen
			if tokenTruncated || origLen == newLen {
				tokenLen--
			}
			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		if common.ContainsStopSuffix(sequence, seq.stop) {
			continue
		}

		if common.IncompleteUnicode(sequence) {
			continue
		}

		if !flushPending(seq) {
			s.removeSequence(i, llm.DoneReasonConnectionClosed)
		}
	}

	return nil
}

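// completion handles a streaming completion request: it builds a sequence from the
// request, waits for a free slot, and streams generated chunks back as JSON.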
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req llm.CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	// Extract options from the CompletionRequest
	samplingParams := llama.SamplingParams{
		TopK:           req.Options.TopK,
		TopP:           req.Options.TopP,
		MinP:           req.Options.MinP,
		TypicalP:       req.Options.TypicalP,
		Temp:           req.Options.Temperature,
		RepeatLastN:    req.Options.RepeatLastN,
		PenaltyRepeat:  req.Options.RepeatPenalty,
		PenaltyFreq:    req.Options.FrequencyPenalty,
		PenaltyPresent: req.Options.PresencePenalty,
		Seed:           uint32(req.Options.Seed),
		Grammar:        req.Grammar,
	}

	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
		numPredict:     req.Options.NumPredict,
		stop:           req.Options.Stop,
		numKeep:        req.Options.NumKeep,
		samplingParams: &samplingParams,
		embedding:      false,
		shift:          req.Shift,
		truncate:       req.Truncate,
		logprobs:       req.Logprobs,
		topLogprobs:    req.TopLogprobs,
	})
	if err != nil {
		if errors.Is(err, errorInputTooLong) {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	for {
		select {
		case <-r.Context().Done():
			close(seq.quit)
			return
		case resp, ok := <-seq.responses:
			if ok {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content:  resp.content,
					Logprobs: resp.logprobs,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					close(seq.quit)
					return
				}

				flusher.Flush()
			} else {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Done:               true,
					DoneReason:         seq.doneReason,
					PromptEvalCount:    seq.numPromptInputs,
					PromptEvalDuration: seq.processingDuration,
					EvalCount:          seq.numDecoded,
					EvalDuration:       seq.generationDuration,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
				}

				return
			}
		}
	}
}

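// embeddings handles an embedding request: it schedules a prompt-only sequence and
// returns the resulting embedding as JSON.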
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
	var req llm.EmbeddingRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
		embedding: true,

		// TODO (jmorganca): this should be provided by the server via the
		// request options and truncated here in the runner, instead of relying on
		// the server's truncate logic
		truncate: true,
	})
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting embeddings request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}
			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	embedding := <-seq.embedding

	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
		Embedding: embedding,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

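// health reports the server's current status and model load progress.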
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
		Status:   s.status,
		Progress: s.progress,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

// loadModel allocates memory based on the given parameters and loads the weights. The
// memory allocated is worst case for text models but not for vision.
func (s *Server) loadModel(
	params llama.ModelParams,
	mpath string,
	lpath []string,
	ppath string,
	kvSize int,
	kvCacheType string,
	flashAttention bool,
	threads int,
	multiUserCache bool,
) {
	var err error
	s.model, err = llama.LoadModelFromFile(mpath, params)
	if err != nil {
		panic(err)
	}

	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
	if err != nil {
		panic(err)
	}

	for _, path := range lpath {
		err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
		if err != nil {
			panic(err)
		}
	}

	if ppath != "" {
		var err error
		s.image, err = NewImageContext(s.lc, ppath)
		if err != nil {
			panic(err)
		}
	}

	s.cache, err = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
	if err != nil {
		panic(err)
	}

	s.status = llm.ServerStatusReady
	s.ready.Done()
}

// load is the handler called by the Ollama server to process different
// load operations
func (s *Server) load(w http.ResponseWriter, r *http.Request) {
	s.loadMu.Lock()
	defer s.loadMu.Unlock()

	w.Header().Set("Content-Type", "application/json")

	if s.status != llm.ServerStatusLaunched {
		http.Error(w, "model already loaded", http.StatusInternalServerError)
		return
	}

	var req llm.LoadRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	slog.Info("load", "request", req)

	switch req.Operation {
	// LoadOperationFit and LoadOperationAlloc have no meaning here - just return a successful response

	case llm.LoadOperationCommit:
		s.batchSize = req.BatchSize
		s.parallel = req.Parallel
		s.seqs = make([]*Sequence, s.parallel)
		s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

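		// Derive the tensor split and total GPU layer count from the requested
		// per-device layer assignments.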
		gpuIDs := llama.EnumerateGPUs()
		tensorSplit := make([]float32, len(gpuIDs))
		numGPU := 0
		for i := range gpuIDs {
			for _, layers := range req.GPULayers {
				if gpuIDs[i] == layers.DeviceID {
					tensorSplit[i] = float32(len(layers.Layers))
					numGPU += len(layers.Layers)
				}
			}
		}

		params := llama.ModelParams{
			NumGpuLayers: numGPU,
			MainGpu:      req.MainGPU,
			UseMmap:      req.UseMmap && len(req.LoraPath) == 0,
			TensorSplit:  tensorSplit,
			Progress: func(progress float32) {
				s.progress = progress
			},
		}

		s.status = llm.ServerStatusLoadingModel
		go s.loadModel(params, s.modelPath, req.LoraPath, req.ProjectorPath, req.KvSize, req.KvCacheType, req.FlashAttention, req.NumThreads, req.MultiUserCache)

	case llm.LoadOperationClose:
		// No-op for us
		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		}
		return
	}

	resp := llm.LoadResponse{Success: true}
	if err := json.NewEncoder(w).Encode(&resp); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

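// Execute is the llama runner entry point: it parses flags, initializes the backend,
// and serves the runner's HTTP API until the listener closes.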
func Execute(args []string) error {
	fs := flag.NewFlagSet("runner", flag.ExitOnError)
	mpath := fs.String("model", "", "Path to model binary file")
	port := fs.Int("port", 8080, "Port to expose the server on")
	_ = fs.Bool("verbose", false, "verbose output (default: disabled)")

	fs.Usage = func() {
		fmt.Fprintf(fs.Output(), "Runner usage\n")
		fs.PrintDefaults()
	}
	if err := fs.Parse(args); err != nil {
		return err
	}
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
	slog.Info("starting go runner")

	llama.BackendInit()

	server := &Server{
		modelPath: *mpath,
		status:    llm.ServerStatusLaunched,
	}

	server.ready.Add(1)

	server.cond = sync.NewCond(&server.mu)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return err
	}
	defer listener.Close()

	mux := http.NewServeMux()
	mux.HandleFunc("POST /load", server.load)
	mux.HandleFunc("/embedding", server.embeddings)
	mux.HandleFunc("/completion", server.completion)
	mux.HandleFunc("/health", server.health)

	httpServer := http.Server{
		Handler: mux,
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
		return err
	}

	return nil
}