runner.go

package llamarunner

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"log"
	"log/slog"
	"net"
	"net/http"
	"os"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/runner/common"
)

// response contains a piece of generated text along with optional logprobs
type response struct {
	content  string
	logprobs []llm.Logprob
}

// input is an element of the prompt to process, either
// a token or an image embedding (generated from a vision projector)
type input struct {
	token int

	// embed is an image embedding
	embed []float32
}

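// Sequence tracks the state of a single generation or embedding request while
// it occupies one of the server's parallel slots.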
type Sequence struct {
	// batch index
	iBatch int

	// number of tokens predicted so far
	numPredicted int

	// prompt inputs left to evaluate
	inputs []input

	// inputs that have been added to a batch but not yet submitted to Decode
	pendingInputs []input

	// tokens that have been generated but not returned yet (e.g. for stop sequences)
	pendingResponses []string

	// logprobs for tokens that haven't been returned yet
	pendingLogprobs []llm.Logprob

	// input cache being used by this sequence
	cache *InputCacheSlot

	// channel to send responses over
	responses chan response

	// channel to stop decoding (such as if the remote connection is closed)
	quit chan bool

	// number of tokens to predict
	numPredict int

	samplingCtx *llama.SamplingContext

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// number of inputs to keep at the beginning when shifting context window
	numKeep int

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool

	// shift if context window is exceeded
	shift bool

	doneReason llm.DoneReason

	// logprobs configuration
	logprobs    bool
	topLogprobs int

	// Metrics
	processingDuration time.Duration
	generationDuration time.Duration
	numDecoded         int
	numPromptInputs    int
}

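// NewSequenceParams carries the per-request settings used to construct a Sequence.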
type NewSequenceParams struct {
	numPredict     int
	stop           []string
	numKeep        int
	samplingParams *llama.SamplingParams
	embedding      bool
	shift          bool
	truncate       bool
	logprobs       bool
	topLogprobs    int
}

var errorInputTooLong = errors.New("the input length exceeds the context length")

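// NewSequence tokenizes the prompt (including any image tags), truncates it to
// fit the context window if permitted, and prepares the sampling state for a
// new request.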
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
	s.ready.Wait()

	inputs, err := s.inputs(prompt, images)
	if err != nil {
		return nil, fmt.Errorf("failed to process inputs: %w", err)
	} else if len(inputs) == 0 {
		return nil, errors.New("no input provided")
	}

	if params.numKeep < 0 {
		params.numKeep = len(inputs)
	}

	if s.model.AddBOSToken() {
		params.numKeep += 1
	}

	// Ensure that at least 1 input can be discarded during shift
	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

	if len(inputs) > s.cache.numCtx {
		discard := len(inputs) - s.cache.numCtx
		if !params.truncate {
			return nil, errorInputTooLong
		}

		newInputs := inputs[:params.numKeep]
		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)

		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
		inputs = newInputs
	}

	var sc *llama.SamplingContext
	if params.samplingParams != nil {
		sc, err = llama.NewSamplingContext(s.model, *params.samplingParams)
		if err != nil {
			return nil, err
		}
		for _, input := range inputs {
			if input.embed == nil {
				sc.Accept(input.token, false)
			}
		}
	}

	return &Sequence{
		inputs:           inputs,
		numPromptInputs:  len(inputs),
		numPredict:       params.numPredict,
		pendingResponses: make([]string, 0),
		responses:        make(chan response, 100),
		quit:             make(chan bool, 1),
		embedding:        make(chan []float32, 1),
		samplingCtx:      sc,
		embeddingOnly:    params.embedding,
		stop:             params.stop,
		numKeep:          params.numKeep,
		shift:            params.shift,
		logprobs:         params.logprobs,
		topLogprobs:      params.topLogprobs,
	}, nil
}

// calculateLogprobsLlama converts raw logits to log probabilities and finds top K tokens
func calculateLogprobsLlama(logits []float32, selectedToken int, topK int, model *llama.Model) []llm.Logprob {
	return common.CalculateLogprobs(logits, selectedToken, topK, model.TokenToPiece)
}

// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
	var inputs []input
	var parts []string
	var matches [][]string

	if s.image != nil {
		re := regexp.MustCompile(`\[img-(\d+)\]`)
		parts = re.Split(prompt, -1)
		matches = re.FindAllStringSubmatch(prompt, -1)
	} else {
		parts = []string{prompt}
	}

	for i, part := range parts {
		// text - tokenize
		tokens, err := s.lc.Model().Tokenize(part, i == 0, true)
		if err != nil {
			return nil, err
		}

		for _, t := range tokens {
			inputs = append(inputs, input{token: t})
		}

		// image - generate image embedding
		if i < len(matches) {
			n, _ := strconv.Atoi(matches[i][1])

			imageIndex := -1
			for j := range images {
				if images[j].ID == n {
					imageIndex = j
					break
				}
			}

			if imageIndex < 0 {
				return nil, fmt.Errorf("invalid image index: %d", n)
			}

			chunks, err := s.image.MultimodalTokenize(s.lc, images[imageIndex].Data)
			if err != nil {
				return nil, err
			}

			for _, c := range chunks {
				if len(c.Embed) != 0 {
					inputs = append(inputs, input{embed: c.Embed})
				} else {
					for _, t := range c.Tokens {
						inputs = append(inputs, input{token: t})
					}
				}
			}
		}
	}

	return inputs, nil
}

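// Server is the llama runner: it owns the loaded model, the input cache, and
// the set of in-flight sequences, and serves the runner's HTTP API.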
type Server struct {
	// modelPath is the location of the model to be loaded
	modelPath string

	// loadMu prevents more than one load attempt from occurring at a time
	loadMu sync.Mutex

	// is the server ready to process requests?
	// protects access to model and image
	ready sync.WaitGroup

	// loaded model
	model *llama.Model

	// image model context for multi-modal models
	image *ImageContext

	// status for external health reporting - loading, ready to serve, etc.
	status llm.ServerStatus

	// current progress on loading the model
	progress float32

	// number of simultaneous requests to handle
	parallel int

	// maximum number of elements in a batch (per sequence)
	// TODO (jmorganca): make this n_batch
	batchSize int

	// protects access to everything below this line
	// this is context state needed for decoding
	mu sync.Mutex

	// indicates that data is ready for processing
	cond *sync.Cond

	// decoding state
	lc *llama.Context

	// the list of simultaneous sequences being evaluated
	seqs []*Sequence

	// seqs can have a maximum of parallel entries, which
	// is enforced by seqsSem
	seqsSem *semaphore.Weighted

	// KV cache
	cache *InputCache

	// next sequence for prompt processing to avoid starvation
	nextSeq int
}

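// allNil reports whether every sequence slot is empty, i.e. there is currently
// nothing to decode.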
func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}
	return true
}

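// flushPending joins any buffered response pieces, trims trailing bytes until
// the result is valid UTF-8, and sends it along with any pending logprobs on
// the sequence's response channel. It returns false if the sequence was
// signaled to quit before the response could be delivered.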
func flushPending(seq *Sequence) bool {
	joined := strings.Join(seq.pendingResponses, "")
	logprobs := seq.pendingLogprobs
	seq.pendingResponses = []string{}
	seq.pendingLogprobs = []llm.Logprob{}

	// Check if there are any partial UTF-8 characters remaining.
	// We already check and queue as we are generating but some may
	// still make it here:
	// - Sequence is ending, e.g. generation limit has been hit
	// - Invalid characters in the middle of a string
	// This is a stricter check to ensure we never output invalid Unicode.
	for !utf8.ValidString(joined) {
		joined = joined[:len(joined)-1]
	}

	if len(joined) == 0 {
		return true
	}

	select {
	case seq.responses <- response{content: joined, logprobs: logprobs}:
		return true
	case <-seq.quit:
		return false
	}
}

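// removeSequence finishes a sequence: it flushes pending output, records the
// done reason, closes its channels, marks its cache slot as no longer in use,
// and releases the semaphore slot so another request can be scheduled.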
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
	seq := s.seqs[seqIndex]

	flushPending(seq)
	seq.doneReason = reason
	close(seq.responses)
	close(seq.embedding)
	seq.cache.InUse = false
	s.seqs[seqIndex] = nil
	s.seqsSem.Release(1)
}

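// run is the main decode loop: it waits for the model to load, then repeatedly
// fills token and embedding batches from the active sequences and decodes them
// via processBatch until the context is cancelled.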
func (s *Server) run(ctx context.Context) {
	s.ready.Wait()

	// Logically these batches are used only within the context of processBatch
	// but it is better for performance to allocate them once here
	tokenBatch, err := llama.NewBatch(s.batchSize, len(s.seqs), 0)
	if err != nil {
		panic(err)
	}
	defer tokenBatch.Free()

	var embedBatch *llama.Batch
	embedBatchSize := s.image.BatchSize(s.batchSize)
	if embedBatchSize != 0 {
		embedBatch, err = llama.NewBatch(embedBatchSize, len(s.seqs), s.image.EmbedSize(s.lc))
		if err != nil {
			panic(err)
		}
		defer embedBatch.Free()
	} else {
		embedBatch = &llama.Batch{}
	}

	for {
		select {
		case <-ctx.Done():
			return
		default:
			err := s.processBatch(tokenBatch, embedBatch)
			if err != nil {
				panic(err)
			}

			tokenBatch.Clear()
			embedBatch.Clear()
		}
	}
}

// TODO (jmorganca): processBatch should be simplified, removing:
// * sampling
// * stop token checking
// * metrics
// these should instead be handled by the handlers
// it should only be responsible for accepting tokens or embeddings and
// processing batches as fast as possible
func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) error {
	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	var batch *llama.Batch
	var numOutputs int

	seqIdx := s.nextSeq - 1
	for range s.seqs {
		seqIdx = (seqIdx + 1) % len(s.seqs)
		seq := s.seqs[seqIdx]

		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
			s.removeSequence(seqIdx, llm.DoneReasonLength)
			continue
		}

		for i, input := range seq.inputs {
			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
				if len(seq.pendingInputs) == 0 {
					if !seq.shift {
						s.removeSequence(seqIdx, llm.DoneReasonLength)
						break
					}

					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
					if err != nil {
						var reprocess *ErrReprocessInputs
						if errors.As(err, &reprocess) {
							// Prepend these inputs to the sequence's inputs queue for reprocessing
							seq.inputs = append(reprocess.Inputs, seq.inputs...)
							// Continue processing as normal
							continue
						} else {
							return err
						}
					}
				} else {
					break
				}
			}

			embedding := input.embed != nil

			// If we don't currently have a batch, use one of the correct type and
			// fill it up as much as possible across all sequences. If we encounter an
			// input of the opposite type, stop for that sequence but then pick up from
			// there for the next batch, ensuring that we alternate types
			if batch == nil {
				if !embedding {
					batch = tokenBatch
				} else {
					batch = embedBatch
				}
			} else if embedding != batch.IsEmbedding() {
				s.nextSeq = seqIdx
				break
			}

			if i >= batch.Size() {
				break
			}

			output := i+1 == len(seq.inputs)
			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id)
			if output {
				numOutputs++
			}

			seq.pendingInputs = append(seq.pendingInputs, input)
			seq.iBatch = batch.NumTokens() - 1
		}

		seq.inputs = seq.inputs[len(seq.pendingInputs):]
	}

	if batch == nil || batch.NumTokens() == 0 {
		return nil
	}

	t := time.Now()
	if err := s.lc.Decode(batch); err != nil {
		return fmt.Errorf("failed to decode batch: %w", err)
	}

	if numOutputs > 0 {
		s.lc.Synchronize()
	}

	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}

		// After calling Decode, pending inputs are now in the cache
		if len(seq.pendingInputs) > 0 {
			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
			seq.pendingInputs = []input{}
		}

		// don't sample prompt processing
		if len(seq.inputs) != 0 {
			seq.processingDuration += time.Since(t)
			continue
		}

		seq.numDecoded++
		if seq.numDecoded > 1 {
			seq.generationDuration += time.Since(t)
		} else {
			seq.processingDuration += time.Since(t)
		}

		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			embed := s.lc.GetEmbeddingsSeq(seq.cache.Id)
			if embed == nil {
				embed = s.lc.GetEmbeddingsIth(seq.iBatch)
			}

			seq.embedding <- embed
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// sample a token
		token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
		seq.samplingCtx.Accept(token, true)
		piece := s.model.TokenToPiece(token)

		seq.numPredicted++

		// if it's an end of sequence token, break
		if s.model.TokenIsEog(token) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
		if seq.logprobs {
			logits := s.lc.GetLogitsIth(seq.iBatch)
			if logits != nil {
				logprobs := calculateLogprobsLlama(logits, token, seq.topLogprobs, s.model)
				seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
			}
		}

		seq.inputs = []input{{token: token}}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")

		if ok, stop := common.FindStop(sequence, seq.stop); ok {
			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

			var tokenTruncated bool
			origLen := len(seq.pendingResponses)
			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
			newLen := len(seq.pendingResponses)

			// Truncate logprobs to match the truncated responses
			if seq.logprobs {
				origLogprobsLen := len(seq.pendingLogprobs)
				numTokensRemoved := origLen - newLen
				newLogprobsLen := origLogprobsLen - numTokensRemoved
				if newLogprobsLen < 0 {
					newLogprobsLen = 0
				}
				seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
			}

			// Update the cache based on the tokens that will be returned:
			// - We have 1 token more than is currently in the cache because
			// the last one generated wasn't submitted to Decode
			// - Remove any stop sequences that we stripped out
			// - If truncateStop removed a portion of a token, drop that
			// - As defense-in-depth, if truncateStop didn't find a stop token
			// remove the extra one that we added to the cache len
			tokenLen := len(seq.cache.Inputs) + 1
			tokenLen -= origLen - newLen
			if tokenTruncated || origLen == newLen {
				tokenLen--
			}
			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		if common.ContainsStopSuffix(sequence, seq.stop) {
			continue
		}

		if common.IncompleteUnicode(sequence) {
			continue
		}

		if !flushPending(seq) {
			s.removeSequence(i, llm.DoneReasonConnectionClosed)
		}
	}

	return nil
}

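// completion handles /completion requests: it builds a Sequence from the
// request, waits for a free slot, and streams CompletionResponse chunks until
// the sequence finishes or the client disconnects.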
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req llm.CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	// Extract options from the CompletionRequest
	samplingParams := llama.SamplingParams{
		TopK:           req.Options.TopK,
		TopP:           req.Options.TopP,
		MinP:           req.Options.MinP,
		TypicalP:       req.Options.TypicalP,
		Temp:           req.Options.Temperature,
		RepeatLastN:    req.Options.RepeatLastN,
		PenaltyRepeat:  req.Options.RepeatPenalty,
		PenaltyFreq:    req.Options.FrequencyPenalty,
		PenaltyPresent: req.Options.PresencePenalty,
		Seed:           uint32(req.Options.Seed),
		Grammar:        req.Grammar,
	}

	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
		numPredict:     req.Options.NumPredict,
		stop:           req.Options.Stop,
		numKeep:        req.Options.NumKeep,
		samplingParams: &samplingParams,
		embedding:      false,
		shift:          req.Shift,
		truncate:       req.Truncate,
		logprobs:       req.Logprobs,
		topLogprobs:    req.TopLogprobs,
	})
	if err != nil {
		if errors.Is(err, errorInputTooLong) {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	for {
		select {
		case <-r.Context().Done():
			close(seq.quit)
			return
		case resp, ok := <-seq.responses:
			if ok {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content:  resp.content,
					Logprobs: resp.logprobs,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					close(seq.quit)
					return
				}

				flusher.Flush()
			} else {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Done:               true,
					DoneReason:         seq.doneReason,
					PromptEvalCount:    seq.numPromptInputs,
					PromptEvalDuration: seq.processingDuration,
					EvalCount:          seq.numDecoded,
					EvalDuration:       seq.generationDuration,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
				}

				return
			}
		}
	}
}

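// embeddings handles /embedding requests: it runs the prompt through the model
// in embedding-only mode and returns a single EmbeddingResponse.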
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
	var req llm.EmbeddingRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
		embedding: true,

		// TODO (jmorganca): this should be provided by the server via the
		// request options and truncated here in the runner, instead of relying on
		// the server's truncate logic
		truncate: true,
	})
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting embeddings request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}
			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	embedding := <-seq.embedding

	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
		Embedding: embedding,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

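// health reports the server's current status and model load progress as JSON.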
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
		Status:   s.status,
		Progress: s.progress,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

// loadModel allocates memory based on the given parameters and loads the weights. The
// memory allocated is worst case for text models but not for vision.
func (s *Server) loadModel(
	params llama.ModelParams,
	mpath string,
	lpath []string,
	ppath string,
	kvSize int,
	kvCacheType string,
	flashAttention bool,
	threads int,
	multiUserCache bool,
) {
	var err error
	s.model, err = llama.LoadModelFromFile(mpath, params)
	if err != nil {
		panic(err)
	}

	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
	if err != nil {
		panic(err)
	}

	for _, path := range lpath {
		err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
		if err != nil {
			panic(err)
		}
	}

	if ppath != "" {
		var err error
		s.image, err = NewImageContext(s.lc, ppath)
		if err != nil {
			panic(err)
		}
	}

	s.cache, err = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
	if err != nil {
		panic(err)
	}

	s.status = llm.ServerStatusReady
	s.ready.Done()
}

// load is the handler called by the Ollama server to process different
// load operations
func (s *Server) load(w http.ResponseWriter, r *http.Request) {
	s.loadMu.Lock()
	defer s.loadMu.Unlock()

	w.Header().Set("Content-Type", "application/json")

	if s.status != llm.ServerStatusLaunched {
		http.Error(w, "model already loaded", http.StatusInternalServerError)
		return
	}

	var req llm.LoadRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	slog.Info("load", "request", req)

	switch req.Operation {
	// LoadOperationFit and LoadOperationAlloc have no meaning here - just return a successful response

	case llm.LoadOperationCommit:
		s.batchSize = req.BatchSize
		s.parallel = req.Parallel
		s.seqs = make([]*Sequence, s.parallel)
		s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

		numGPU := 0
		var tensorSplit []float32
		var llamaIDs []uint64

		gpuIDs := llama.EnumerateGPUs()
		sort.Sort(req.GPULayers)
		for _, layers := range req.GPULayers {
			for i := range gpuIDs {
				if gpuIDs[i].DeviceID == layers.DeviceID {
					numGPU += len(layers.Layers)
					tensorSplit = append(tensorSplit, float32(len(layers.Layers)))
					llamaIDs = append(llamaIDs, gpuIDs[i].LlamaID)
				}
			}
		}

		params := llama.ModelParams{
			Devices:      llamaIDs,
			NumGpuLayers: numGPU,
			MainGpu:      req.MainGPU,
			UseMmap:      req.UseMmap && len(req.LoraPath) == 0,
			TensorSplit:  tensorSplit,
			Progress: func(progress float32) {
				s.progress = progress
			},
		}

		s.status = llm.ServerStatusLoadingModel
		go s.loadModel(params, s.modelPath, req.LoraPath, req.ProjectorPath, req.KvSize, req.KvCacheType, req.FlashAttention, req.NumThreads, req.MultiUserCache)

	case llm.LoadOperationClose:
		// No-op for us
		if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		}
		return
	}

	resp := llm.LoadResponse{Success: true}
	if err := json.NewEncoder(w).Encode(&resp); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

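// Execute parses the runner's command line flags, initializes the llama
// backend, and serves the runner HTTP API (load, completion, embedding,
// health) until the listener closes.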
func Execute(args []string) error {
	fs := flag.NewFlagSet("runner", flag.ExitOnError)
	mpath := fs.String("model", "", "Path to model binary file")
	port := fs.Int("port", 8080, "Port to expose the server on")
	_ = fs.Bool("verbose", false, "verbose output (default: disabled)")

	fs.Usage = func() {
		fmt.Fprintf(fs.Output(), "Runner usage\n")
		fs.PrintDefaults()
	}
	if err := fs.Parse(args); err != nil {
		return err
	}
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
	slog.Info("starting go runner")

	llama.BackendInit()

	server := &Server{
		modelPath: *mpath,
		status:    llm.ServerStatusLaunched,
	}

	server.ready.Add(1)

	server.cond = sync.NewCond(&server.mu)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return err
	}
	defer listener.Close()

	mux := http.NewServeMux()
	mux.HandleFunc("POST /load", server.load)
	mux.HandleFunc("/embedding", server.embeddings)
	mux.HandleFunc("/completion", server.completion)
	mux.HandleFunc("/health", server.health)

	httpServer := http.Server{
		Handler: mux,
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
		return err
	}

	return nil
}