Unverified commit e53b3cbd authored by Bruce MacDonald, committed by GitHub

llm: set done reason at server level (#9830)

No functional change. Many different done reasons can be set at the runner
level, so rather than obscuring them we should return them to the server
process and let it choose what to do with the done reason. This separates
the API concerns from the runner.
parent b51e0f39
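
For orientation, here is a self-contained sketch of the pattern this commit introduces. The DoneReason type, constants, and String method mirror the additions to the llm package in the diff below; the main function is illustrative only.

```go
package main

import "fmt"

// DoneReason mirrors the type added to the llm package in this commit.
type DoneReason int

const (
	DoneReasonStop             DoneReason = iota // stopped naturally
	DoneReasonLength                             // hit the predict limit
	DoneReasonConnectionClosed                   // client went away
)

// String maps the typed reason to the API wire string, as in the diff.
func (d DoneReason) String() string {
	switch d {
	case DoneReasonLength:
		return "length"
	case DoneReasonStop:
		return "stop"
	default:
		return "" // connection closed: nothing to report to the client
	}
}

func main() {
	// The runner reports a typed reason; the server decides how (and
	// whether) to surface it in the API response.
	for _, r := range []DoneReason{DoneReasonStop, DoneReasonLength, DoneReasonConnectionClosed} {
		fmt.Printf("runner reason %d -> api %q\n", int(r), r.String())
	}
}
```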
@@ -675,9 +675,32 @@ type CompletionRequest struct {
 	Grammar string // set before sending the request to the subprocess
 }
 
+// DoneReason represents the reason why a completion response is done
+type DoneReason int
+
+const (
+	// DoneReasonStop indicates the completion stopped naturally
+	DoneReasonStop DoneReason = iota
+	// DoneReasonLength indicates the completion stopped due to length limits
+	DoneReasonLength
+	// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
+	DoneReasonConnectionClosed
+)
+
+func (d DoneReason) String() string {
+	switch d {
+	case DoneReasonLength:
+		return "length"
+	case DoneReasonStop:
+		return "stop"
+	default:
+		return "" // closed
+	}
+}
+
 type CompletionResponse struct {
 	Content            string        `json:"content"`
-	DoneReason         string        `json:"done_reason"`
+	DoneReason         DoneReason    `json:"done_reason"`
 	Done               bool          `json:"done"`
 	PromptEvalCount    int           `json:"prompt_eval_count"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
@@ -786,7 +809,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			continue
 		}
 
-		// slog.Debug("got line", "line", string(line))
 		evt, ok := bytes.CutPrefix(line, []byte("data: "))
 		if !ok {
 			evt = line
...
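
One consequence of the new field type worth noting: since DoneReason is a plain int and no custom JSON marshaling appears in this diff, done_reason now travels between the runner subprocess and the server as a number rather than a string. A minimal sketch of that round trip, with the struct trimmed to the relevant fields:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type DoneReason int

const (
	DoneReasonStop DoneReason = iota
	DoneReasonLength
	DoneReasonConnectionClosed
)

// Trimmed-down CompletionResponse carrying only the fields relevant here.
type CompletionResponse struct {
	Content    string     `json:"content"`
	DoneReason DoneReason `json:"done_reason"`
	Done       bool       `json:"done"`
}

func main() {
	// The runner encodes the typed reason as its integer value...
	b, err := json.Marshal(CompletionResponse{Done: true, DoneReason: DoneReasonLength})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b)) // {"content":"","done_reason":1,"done":true}

	// ...and the server decodes it back into the same typed value.
	var r CompletionResponse
	if err := json.Unmarshal(b, &r); err != nil {
		panic(err)
	}
	fmt.Println(r.DoneReason == DoneReasonLength) // true
}
```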
@@ -83,7 +83,7 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
-	doneReason string
+	doneReason llm.DoneReason
 
 	// Metrics
 	startProcessingTime time.Time
@@ -301,7 +301,7 @@ func flushPending(seq *Sequence) bool {
 	}
 }
 
-func (s *Server) removeSequence(seqIndex int, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
 	flushPending(seq)
@@ -380,7 +380,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(seqIdx, llm.DoneReasonLength)
 			continue
 		}
 
@@ -482,7 +482,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 
 			seq.embedding <- embed
-			s.removeSequence(i, "")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -499,7 +499,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -530,7 +530,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -543,7 +543,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		}
 
 		if !flushPending(seq) {
-			s.removeSequence(i, "connection")
+			s.removeSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}
@@ -657,14 +657,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			flusher.Flush()
 		} else {
-			// Send the final response
-			doneReason := "stop"
-			if seq.doneReason == "limit" {
-				doneReason = "length"
-			}
 			if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 				Done:               true,
-				DoneReason:         doneReason,
+				DoneReason:         seq.doneReason,
 				PromptEvalCount:    seq.numPromptInputs,
 				PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 				EvalCount:          seq.numDecoded,
...
@@ -82,7 +82,7 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
-	doneReason string
+	doneReason llm.DoneReason
 
 	// Metrics
 	startProcessingTime time.Time
@@ -341,7 +341,7 @@ func flushPending(seq *Sequence) bool {
 	}
 }
 
-func (s *Server) removeSequence(seqIndex int, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
 	flushPending(seq)
@@ -391,7 +391,7 @@ func (s *Server) processBatch() error {
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(seqIdx, llm.DoneReasonLength)
 			continue
 		}
 
@@ -510,7 +510,7 @@ func (s *Server) processBatch() error {
 		if seq.embeddingOnly {
 			// TODO(jessegross): Embedding support
 			slog.Warn("generation of embedding outputs not yet supported")
-			s.removeSequence(i, "")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -528,7 +528,7 @@ func (s *Server) processBatch() error {
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -564,7 +564,7 @@ func (s *Server) processBatch() error {
 			}
 			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -577,7 +577,7 @@ func (s *Server) processBatch() error {
 		}
 
 		if !flushPending(seq) {
-			s.removeSequence(i, "connection")
+			s.removeSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}
@@ -690,14 +690,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			flusher.Flush()
 		} else {
-			// Send the final response
-			doneReason := "stop"
-			if seq.doneReason == "limit" {
-				doneReason = "length"
-			}
 			if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 				Done:               true,
-				DoneReason:         doneReason,
+				DoneReason:         seq.doneReason,
 				PromptEvalCount:    seq.numPromptInputs,
 				PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 				EvalCount:          seq.numPredicted,
...
@@ -308,11 +308,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		Options: opts,
 	}, func(cr llm.CompletionResponse) {
 		res := api.GenerateResponse{
-			Model:      req.Model,
-			CreatedAt:  time.Now().UTC(),
-			Response:   cr.Content,
-			Done:       cr.Done,
-			DoneReason: cr.DoneReason,
+			Model:     req.Model,
+			CreatedAt: time.Now().UTC(),
+			Response:  cr.Content,
+			Done:      cr.Done,
 			Metrics: api.Metrics{
 				PromptEvalCount:    cr.PromptEvalCount,
 				PromptEvalDuration: cr.PromptEvalDuration,
@@ -326,6 +325,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		if cr.Done {
+			res.DoneReason = cr.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -1533,11 +1533,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		Options: opts,
 	}, func(r llm.CompletionResponse) {
 		res := api.ChatResponse{
-			Model:      req.Model,
-			CreatedAt:  time.Now().UTC(),
-			Message:    api.Message{Role: "assistant", Content: r.Content},
-			Done:       r.Done,
-			DoneReason: r.DoneReason,
+			Model:     req.Model,
+			CreatedAt: time.Now().UTC(),
+			Message:   api.Message{Role: "assistant", Content: r.Content},
+			Done:      r.Done,
 			Metrics: api.Metrics{
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
@@ -1547,6 +1546,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 
 		if r.Done {
+			res.DoneReason = r.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 		}
...
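
Moving the DoneReason assignment inside the if cr.Done / if r.Done branches means intermediate streamed chunks no longer carry a reason; only the final response does. A sketch of that flow using stand-in types (the real handlers use llm.CompletionResponse and api.GenerateResponse):

```go
package main

import "fmt"

type DoneReason int

const (
	DoneReasonStop DoneReason = iota
	DoneReasonLength
)

func (d DoneReason) String() string {
	if d == DoneReasonLength {
		return "length"
	}
	return "stop"
}

// chunk stands in for llm.CompletionResponse.
type chunk struct {
	Content    string
	Done       bool
	DoneReason DoneReason
}

// apiResponse stands in for api.GenerateResponse.
type apiResponse struct {
	Response   string
	Done       bool
	DoneReason string
}

func main() {
	stream := []chunk{
		{Content: "Hel"},
		{Content: "lo"},
		{Done: true, DoneReason: DoneReasonStop},
	}
	for _, cr := range stream {
		res := apiResponse{Response: cr.Content, Done: cr.Done}
		if cr.Done {
			// Only the final chunk gets a done reason, as in the handlers above.
			res.DoneReason = cr.DoneReason.String()
		}
		fmt.Printf("%+v\n", res)
	}
}
```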
...@@ -58,7 +58,7 @@ func TestGenerateChat(t *testing.T) { ...@@ -58,7 +58,7 @@ func TestGenerateChat(t *testing.T) {
mock := mockRunner{ mock := mockRunner{
CompletionResponse: llm.CompletionResponse{ CompletionResponse: llm.CompletionResponse{
Done: true, Done: true,
DoneReason: "stop", DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1, PromptEvalCount: 1,
PromptEvalDuration: 1, PromptEvalDuration: 1,
EvalCount: 1, EvalCount: 1,
...@@ -401,7 +401,7 @@ func TestGenerateChat(t *testing.T) { ...@@ -401,7 +401,7 @@ func TestGenerateChat(t *testing.T) {
mock.CompletionResponse = llm.CompletionResponse{ mock.CompletionResponse = llm.CompletionResponse{
Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`, Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
Done: true, Done: true,
DoneReason: "done", DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1, PromptEvalCount: 1,
PromptEvalDuration: 1, PromptEvalDuration: 1,
EvalCount: 1, EvalCount: 1,
...@@ -519,7 +519,7 @@ func TestGenerateChat(t *testing.T) { ...@@ -519,7 +519,7 @@ func TestGenerateChat(t *testing.T) {
{ {
Content: `, WA","unit":"celsius"}}`, Content: `, WA","unit":"celsius"}}`,
Done: true, Done: true,
DoneReason: "tool_call", DoneReason: llm.DoneReasonStop,
PromptEvalCount: 3, PromptEvalCount: 3,
PromptEvalDuration: 1, PromptEvalDuration: 1,
}, },
...@@ -594,7 +594,7 @@ func TestGenerate(t *testing.T) { ...@@ -594,7 +594,7 @@ func TestGenerate(t *testing.T) {
mock := mockRunner{ mock := mockRunner{
CompletionResponse: llm.CompletionResponse{ CompletionResponse: llm.CompletionResponse{
Done: true, Done: true,
DoneReason: "stop", DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1, PromptEvalCount: 1,
PromptEvalDuration: 1, PromptEvalDuration: 1,
EvalCount: 1, EvalCount: 1,
......
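
The test updates above only swap the mock values over to the typed constants. A table-driven test pinning down the String mapping itself could look like the following; this is a hypothetical addition for illustration, not part of the commit:

```go
package llm

import "testing"

func TestDoneReasonString(t *testing.T) {
	cases := []struct {
		reason DoneReason
		want   string
	}{
		{DoneReasonStop, "stop"},
		{DoneReasonLength, "length"},
		{DoneReasonConnectionClosed, ""}, // intentionally empty on the wire
	}
	for _, c := range cases {
		if got := c.reason.String(); got != c.want {
			t.Errorf("DoneReason(%d).String() = %q, want %q", int(c.reason), got, c.want)
		}
	}
}
```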