Unverified Commit 6ee8c801 authored by Bruce MacDonald's avatar Bruce MacDonald Committed by GitHub
Browse files

restore model load duration on generate response (#1524)

* restore model load duration on generate response

- set model load duration on generate and chat done response
- calculate createdAt time when response is created

* remove checkpoints predict opts

* Update routes.go
parent 31f0551d
...@@ -551,14 +551,9 @@ type PredictOpts struct { ...@@ -551,14 +551,9 @@ type PredictOpts struct {
Prompt string Prompt string
Format string Format string
Images []api.ImageData Images []api.ImageData
CheckpointStart time.Time
CheckpointLoaded time.Time
} }
type PredictResult struct { type PredictResult struct {
CreatedAt time.Time
TotalDuration time.Duration
LoadDuration time.Duration
Content string Content string
Done bool Done bool
PromptEvalCount int PromptEvalCount int
...@@ -681,16 +676,12 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred ...@@ -681,16 +676,12 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
if p.Content != "" { if p.Content != "" {
fn(PredictResult{ fn(PredictResult{
CreatedAt: time.Now().UTC(),
Content: p.Content, Content: p.Content,
}) })
} }
if p.Stop { if p.Stop {
fn(PredictResult{ fn(PredictResult{
CreatedAt: time.Now().UTC(),
TotalDuration: time.Since(predict.CheckpointStart),
Done: true, Done: true,
PromptEvalCount: p.Timings.PromptN, PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS), PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
......
...@@ -261,12 +261,10 @@ func GenerateHandler(c *gin.Context) { ...@@ -261,12 +261,10 @@ func GenerateHandler(c *gin.Context) {
resp := api.GenerateResponse{ resp := api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: r.CreatedAt, CreatedAt: time.Now().UTC(),
Done: r.Done, Done: r.Done,
Response: r.Content, Response: r.Content,
Metrics: api.Metrics{ Metrics: api.Metrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: r.EvalCount,
...@@ -274,7 +272,11 @@ func GenerateHandler(c *gin.Context) { ...@@ -274,7 +272,11 @@ func GenerateHandler(c *gin.Context) {
}, },
} }
if r.Done && !req.Raw { if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String()) embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
...@@ -282,6 +284,7 @@ func GenerateHandler(c *gin.Context) { ...@@ -282,6 +284,7 @@ func GenerateHandler(c *gin.Context) {
} }
resp.Context = embd resp.Context = embd
} }
}
ch <- resp ch <- resp
} }
...@@ -290,8 +293,6 @@ func GenerateHandler(c *gin.Context) { ...@@ -290,8 +293,6 @@ func GenerateHandler(c *gin.Context) {
predictReq := llm.PredictOpts{ predictReq := llm.PredictOpts{
Prompt: prompt, Prompt: prompt,
Format: req.Format, Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
Images: req.Images, Images: req.Images,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
...@@ -1012,11 +1013,9 @@ func ChatHandler(c *gin.Context) { ...@@ -1012,11 +1013,9 @@ func ChatHandler(c *gin.Context) {
resp := api.ChatResponse{ resp := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: r.CreatedAt, CreatedAt: time.Now().UTC(),
Done: r.Done, Done: r.Done,
Metrics: api.Metrics{ Metrics: api.Metrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: r.EvalCount,
...@@ -1024,7 +1023,10 @@ func ChatHandler(c *gin.Context) { ...@@ -1024,7 +1023,10 @@ func ChatHandler(c *gin.Context) {
}, },
} }
if !r.Done { if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} else {
resp.Message = &api.Message{Role: "assistant", Content: r.Content} resp.Message = &api.Message{Role: "assistant", Content: r.Content}
} }
...@@ -1035,8 +1037,6 @@ func ChatHandler(c *gin.Context) { ...@@ -1035,8 +1037,6 @@ func ChatHandler(c *gin.Context) {
predictReq := llm.PredictOpts{ predictReq := llm.PredictOpts{
Prompt: prompt, Prompt: prompt,
Format: req.Format, Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
Images: images, Images: images,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment