"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "b0135f4b9b176eab9155b660d04c9ca2a1ec2341"
Commit 05e08d23 authored by Michael Yang

return more info in generate response

parent 31590284
 package api

-import "runtime"
+import (
+	"fmt"
+	"os"
+	"runtime"
+	"time"
+)
 type PullRequest struct {
 	Model string `json:"model"`

@@ -20,7 +25,41 @@ type GenerateRequest struct {
 }
 type GenerateResponse struct {
-	Response string `json:"response"`
+	Model     string    `json:"model"`
+	CreatedAt time.Time `json:"created_at"`
+	Response  string    `json:"response,omitempty"`
+
+	Done bool `json:"done"`
+
+	TotalDuration      time.Duration `json:"total_duration,omitempty"`
+	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
+	EvalCount          int           `json:"eval_count,omitempty"`
+	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
 }
+
+func (r *GenerateResponse) Summary() {
+	if r.TotalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
+	}
+
+	if r.PromptEvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
+	}
+
+	if r.PromptEvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
+		fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
+	}
+
+	if r.EvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
+	}
+
+	if r.EvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration)
+		fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
+	}
+}

 type Options struct {
...
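An aside for consumers of this API: the new timing fields are time.Duration values, which encoding/json serializes as integer nanosecond counts, and their omitempty tags keep them out of the intermediate stream messages where they are zero. A minimal sketch of how the final message round-trips, assuming the api package as changed above (the model name and timing values are invented for illustration):

package main

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// hypothetical final message of a stream; all values are invented
	resp := api.GenerateResponse{
		Model:         "llama",
		CreatedAt:     time.Now().UTC(),
		Done:          true,
		TotalDuration: 5 * time.Second,
		EvalCount:     120,
		EvalDuration:  4 * time.Second,
	}

	out, err := json.Marshal(resp)
	if err != nil {
		panic(err)
	}

	// durations appear as nanosecond integers, e.g. "total_duration":5000000000
	fmt.Println(string(out))

	// prints the timing lines (total duration, eval rate, ...) to stderr
	resp.Summary()
}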
@@ -72,20 +72,20 @@ func pull(model string) error {
 	)
 }

-func RunGenerate(_ *cobra.Command, args []string) error {
+func RunGenerate(cmd *cobra.Command, args []string) error {
 	if len(args) > 1 {
 		// join all args into a single prompt
-		return generate(args[0], strings.Join(args[1:], " "))
+		return generate(cmd, args[0], strings.Join(args[1:], " "))
 	}

 	if term.IsTerminal(int(os.Stdin.Fd())) {
-		return generateInteractive(args[0])
+		return generateInteractive(cmd, args[0])
 	}

-	return generateBatch(args[0])
+	return generateBatch(cmd, args[0])
 }

-func generate(model, prompt string) error {
+func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
 		client := api.NewClient()
@@ -108,12 +108,16 @@ func generate(model, prompt string) error {
 		}
 	}()

+	var latest api.GenerateResponse
+
 	request := api.GenerateRequest{Model: model, Prompt: prompt}
 	fn := func(resp api.GenerateResponse) error {
 		if !spinner.IsFinished() {
 			spinner.Finish()
 		}

+		latest = resp
+
 		fmt.Print(resp.Response)
 		return nil
 	}
@@ -124,16 +128,25 @@ func generate(model, prompt string) error {
 		fmt.Println()
 		fmt.Println()

+		verbose, err := cmd.Flags().GetBool("verbose")
+		if err != nil {
+			return err
+		}
+
+		if verbose {
+			latest.Summary()
+		}
 	}

 	return nil
 }
-func generateInteractive(model string) error {
+func generateInteractive(cmd *cobra.Command, model string) error {
 	fmt.Print(">>> ")
 	scanner := bufio.NewScanner(os.Stdin)
 	for scanner.Scan() {
-		if err := generate(model, scanner.Text()); err != nil {
+		if err := generate(cmd, model, scanner.Text()); err != nil {
 			return err
 		}
@@ -143,12 +156,12 @@ func generateInteractive(model string) error {
 	return nil
 }

-func generateBatch(model string) error {
+func generateBatch(cmd *cobra.Command, model string) error {
 	scanner := bufio.NewScanner(os.Stdin)
 	for scanner.Scan() {
 		prompt := scanner.Text()
 		fmt.Printf(">>> %s\n", prompt)
-		if err := generate(model, prompt); err != nil {
+		if err := generate(cmd, model, prompt); err != nil {
 			return err
 		}
 	}
@@ -200,6 +213,8 @@ func NewCLI() *cobra.Command {
 		RunE: RunRun,
 	}

+	runCmd.Flags().Bool("verbose", false, "Show timings for response")
+
 	serveCmd := &cobra.Command{
 		Use:     "serve",
 		Aliases: []string{"start"},
...
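With the flag registered on runCmd above, an invocation along the lines of `ollama run some-model "hello" --verbose` (model name hypothetical) should print the response as before and then the Summary() timing lines to stderr once the final Done message arrives; without --verbose, output is unchanged, since the summary is gated on the flag.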
@@ -79,9 +79,11 @@ llama_token llama_sample(
import "C" import "C"
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"strings" "strings"
"time"
"unsafe" "unsafe"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
@@ -147,7 +149,7 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }

-func (llm *llama) Predict(prompt string, fn func(string)) error {
+func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
 	if tokens := llm.tokenize(prompt); tokens != nil {
 		return llm.generate(tokens, fn)
 	}
@@ -176,7 +178,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 }

-func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
+func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
 	var opts C.struct_llama_sample_options
 	opts.repeat_penalty = C.float(llm.RepeatPenalty)
 	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
@@ -190,38 +192,58 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
 	opts.mirostat_tau = C.float(llm.MirostatTau)
 	opts.mirostat_eta = C.float(llm.MirostatEta)

-	pastTokens := deque[C.llama_token]{capacity: llm.RepeatLastN}
+	output := deque[C.llama_token]{capacity: llm.NumCtx}

 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
 		}

-		token, err := llm.sample(pastTokens, &opts)
-		switch {
-		case errors.Is(err, io.EOF):
-			return nil
-		case err != nil:
+		token, err := llm.sample(output, &opts)
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
 			return err
 		}

-		fn(llm.detokenize(token))
+		// call the callback
+		fn(api.GenerateResponse{
+			Response: llm.detokenize(token),
+		})

-		tokens = []C.llama_token{token}
+		output.PushLeft(token)

-		pastTokens.PushLeft(token)
+		input = []C.llama_token{token}
 	}

+	dur := func(ms float64) time.Duration {
+		d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+		if err != nil {
+			panic(err)
+		}
+
+		return d
+	}
+
+	timings := C.llama_get_timings(llm.ctx)
+	fn(api.GenerateResponse{
+		Done:               true,
+		PromptEvalCount:    int(timings.n_p_eval),
+		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
+		EvalCount:          int(timings.n_eval),
+		EvalDuration:       dur(float64(timings.t_eval_ms)),
+	})
+
 	return nil
 }
-func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
+func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
 	numVocab := int(C.llama_n_vocab(llm.ctx))

 	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)

-	candidates := make([]C.struct_llama_token_data, 0, numVocab)
-	for i := 0; i < numVocab; i++ {
-		candidates = append(candidates, C.llama_token_data{
+	candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
+	for i := 0; i < candidates.Cap(); i++ {
+		candidates.PushLeft(C.struct_llama_token_data{
 			id:    C.int(i),
 			logit: logits[i],
 			p:     0,
@@ -230,8 +252,8 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
 	token := C.llama_sample(
 		llm.ctx,
-		unsafe.SliceData(candidates), C.ulong(len(candidates)),
-		unsafe.SliceData(pastTokens.Data()), C.ulong(pastTokens.Len()),
+		unsafe.SliceData(candidates.Data()), C.ulong(candidates.Len()),
+		unsafe.SliceData(output.Data()), C.ulong(output.Len()),
 		opts)
 	if token != C.llama_token_eos() {
 		return token, nil
...
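Both hunks above lean on a generic deque type that is defined elsewhere in the repo and not touched by this commit. As a rough sketch of the interface the call sites assume (a capacity-bounded container whose PushLeft prepends and evicts the oldest entry, plus Data/Len/Cap accessors); the actual implementation may differ:

package main

import "fmt"

// deque is a sketch of the container the diff assumes: capacity-bounded,
// PushLeft prepends and evicts the oldest (rightmost) element when full.
type deque[T any] struct {
	data     []T
	capacity int
}

func (d *deque[T]) PushLeft(t T) {
	d.data = append([]T{t}, d.data...)
	if d.capacity > 0 && len(d.data) > d.capacity {
		d.data = d.data[:d.capacity]
	}
}

func (d *deque[T]) Data() []T { return d.data }
func (d *deque[T]) Len() int  { return len(d.data) }
func (d *deque[T]) Cap() int  { return d.capacity }

func main() {
	window := deque[int]{capacity: 3}
	for _, t := range []int{1, 2, 3, 4} {
		window.PushLeft(t)
	}
	fmt.Println(window.Data()) // [4 3 2], oldest element evicted
}

One design note: the dur helper in the generate hunk converts llama.cpp's millisecond timings by formatting and re-parsing a duration string; time.Duration(ms * float64(time.Millisecond)) would compute the same value without the string round trip.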
@@ -13,6 +13,7 @@ import (
"path" "path"
"strings" "strings"
"text/template" "text/template"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/lithammer/fuzzysearch/fuzzy" "github.com/lithammer/fuzzysearch/fuzzy"
@@ -35,6 +36,8 @@ func cacheDir() string {
 }

 func generate(c *gin.Context) {
+	start := time.Now()
+
 	req := api.GenerateRequest{
 		Options: api.DefaultOptions(),
 	}
@@ -81,8 +84,14 @@ func generate(c *gin.Context) {
 	}
 	defer llm.Close()

-	fn := func(s string) {
-		ch <- api.GenerateResponse{Response: s}
+	fn := func(r api.GenerateResponse) {
+		r.Model = req.Model
+		r.CreatedAt = time.Now().UTC()
+		if r.Done {
+			r.TotalDuration = time.Since(start)
+		}
+
+		ch <- r
 	}

 	if err := llm.Predict(req.Prompt, fn); err != nil {
...
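Taken together, the handler now emits one JSON object per generated token and a final object with Done set and the timing fields populated, stamping Model, CreatedAt, and TotalDuration server-side. A client-side sketch of consuming that stream (the /api/generate route and the localhost:11434 address are assumptions based on the rest of the codebase, not shown in this diff; the model name is hypothetical):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/jmorganca/ollama/api"
)

func main() {
	body, _ := json.Marshal(api.GenerateRequest{Model: "llama", Prompt: "hello"})

	// assumed default server address and route
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	dec := json.NewDecoder(resp.Body)
	for {
		var r api.GenerateResponse
		if err := dec.Decode(&r); err != nil {
			break // io.EOF once the stream ends
		}

		fmt.Print(r.Response)
		if r.Done {
			r.Summary() // timing fields only arrive on the final message
		}
	}
}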