Unverified commit 6ee8c801 authored by Bruce MacDonald, committed by GitHub

restore model load duration on generate response (#1524)

* restore model load duration on generate response

- set model load duration on the final (done) generate and chat responses
- calculate the createdAt time when the response is created

* remove checkpoint fields from PredictOpts

* Update routes.go
parent 31f0551d
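In outline: the runner no longer stamps timestamps or durations on its results. Each handler instead records two checkpoints, `checkpointStart` (taken when the request begins, before the model is loaded) and `checkpointLoaded` (taken once the model is resident), and derives both durations from them when it builds the final response. A minimal runnable sketch of that pattern, with a hypothetical `loadModel` standing in for the real model-loading path:

```go
package main

import (
	"fmt"
	"time"
)

// loadModel is a hypothetical stand-in for the real model-loading path.
func loadModel() { time.Sleep(50 * time.Millisecond) }

func main() {
	checkpointStart := time.Now() // request begins
	loadModel()
	checkpointLoaded := time.Now() // model is now resident

	// ... streaming generation would happen here ...

	// On the final (done) response, both durations come from the checkpoints:
	totalDuration := time.Since(checkpointStart)          // whole request
	loadDuration := checkpointLoaded.Sub(checkpointStart) // model load only
	fmt.Println(totalDuration, loadDuration)
}
```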
llm/llama.go
@@ -548,17 +548,12 @@ const maxBufferSize = 512 * format.KiloByte
 const maxRetries = 6
 
 type PredictOpts struct {
-	Prompt           string
-	Format           string
-	Images           []api.ImageData
-	CheckpointStart  time.Time
-	CheckpointLoaded time.Time
+	Prompt string
+	Format string
+	Images []api.ImageData
 }
 
 type PredictResult struct {
-	CreatedAt          time.Time
-	TotalDuration      time.Duration
-	LoadDuration       time.Duration
 	Content            string
 	Done               bool
 	PromptEvalCount    int
@@ -681,16 +676,12 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 			if p.Content != "" {
 				fn(PredictResult{
-					CreatedAt: time.Now().UTC(),
-					Content:   p.Content,
+					Content: p.Content,
 				})
 			}
 
 			if p.Stop {
 				fn(PredictResult{
-					CreatedAt:     time.Now().UTC(),
-					TotalDuration: time.Since(predict.CheckpointStart),
-
 					Done:               true,
 					PromptEvalCount:    p.Timings.PromptN,
 					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
...
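The stop result above still carries llama.cpp's raw timings, which `parseDurationMs` converts from fractional milliseconds into `time.Duration`. A plausible sketch of that helper, assuming it simply reformats the float for `time.ParseDuration` (the real one lives elsewhere in this file and may differ):

```go
package main

import (
	"fmt"
	"time"
)

// parseDurationMs converts a fractional millisecond count, as reported by
// the llama.cpp server timings, into a time.Duration.
func parseDurationMs(ms float64) time.Duration {
	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
	if err != nil {
		panic(err) // the formatted string is always a valid duration
	}
	return dur
}

func main() {
	fmt.Println(parseDurationMs(1234.5)) // 1.2345s
}
```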
server/routes.go
@@ -261,12 +261,10 @@ func GenerateHandler(c *gin.Context)
 		resp := api.GenerateResponse{
 			Model:     req.Model,
-			CreatedAt: r.CreatedAt,
+			CreatedAt: time.Now().UTC(),
 			Done:      r.Done,
 			Response:  r.Content,
 			Metrics: api.Metrics{
-				TotalDuration:      r.TotalDuration,
-				LoadDuration:       r.LoadDuration,
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
 				EvalCount:          r.EvalCount,
 				EvalDuration:       r.EvalDuration,
@@ -274,13 +272,18 @@ func GenerateHandler(c *gin.Context)
 			},
 		}
 
-		if r.Done && !req.Raw {
-			embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
-			if err != nil {
-				ch <- gin.H{"error": err.Error()}
-				return
+		if r.Done {
+			resp.TotalDuration = time.Since(checkpointStart)
+			resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+			if !req.Raw {
+				embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
+				if err != nil {
+					ch <- gin.H{"error": err.Error()}
+					return
+				}
+				resp.Context = embd
 			}
-			resp.Context = embd
 		}
 
 		ch <- resp
@@ -288,11 +291,9 @@ func GenerateHandler(c *gin.Context)
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt:           prompt,
-			Format:           req.Format,
-			CheckpointStart:  checkpointStart,
-			CheckpointLoaded: checkpointLoaded,
-			Images:           req.Images,
+			Prompt: prompt,
+			Format: req.Format,
+			Images: req.Images,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
@@ -1012,11 +1013,9 @@ func ChatHandler(c *gin.Context)
 		resp := api.ChatResponse{
 			Model:     req.Model,
-			CreatedAt: r.CreatedAt,
+			CreatedAt: time.Now().UTC(),
 			Done:      r.Done,
 			Metrics: api.Metrics{
-				TotalDuration:      r.TotalDuration,
-				LoadDuration:       r.LoadDuration,
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
 				EvalCount:          r.EvalCount,
 				EvalDuration:       r.EvalDuration,
@@ -1024,7 +1023,10 @@ func ChatHandler(c *gin.Context)
 			},
 		}
 
-		if !r.Done {
+		if r.Done {
+			resp.TotalDuration = time.Since(checkpointStart)
+			resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+		} else {
 			resp.Message = &api.Message{Role: "assistant", Content: r.Content}
 		}
 
@@ -1033,11 +1035,9 @@ func ChatHandler(c *gin.Context)
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt:           prompt,
-			Format:           req.Format,
-			CheckpointStart:  checkpointStart,
-			CheckpointLoaded: checkpointLoaded,
-			Images:           images,
+			Prompt: prompt,
+			Format: req.Format,
+			Images: images,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
...
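For API consumers, the net effect is that the final streamed object regains `total_duration` and `load_duration`, which, like the other metrics, serialize as nanosecond integers. A small sketch of reading them back; the JSON keys match the `api.Metrics` tags, and the sample payload is fabricated for illustration:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// metrics mirrors the duration and count fields of the final (done) response.
type metrics struct {
	TotalDuration time.Duration `json:"total_duration"`
	LoadDuration  time.Duration `json:"load_duration"`
	EvalCount     int           `json:"eval_count"`
	EvalDuration  time.Duration `json:"eval_duration"`
}

func main() {
	// Fabricated sample of a final /api/generate object (durations in ns).
	final := []byte(`{"done":true,"total_duration":5589157167,"load_duration":3013701500,"eval_count":93,"eval_duration":1325948000}`)

	var m metrics
	if err := json.Unmarshal(final, &m); err != nil {
		panic(err)
	}
	fmt.Printf("load: %v of %v total\n", m.LoadDuration, m.TotalDuration)
	fmt.Printf("eval rate: %.1f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
}
```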