Commit 1775647f authored by Michael Yang

continue conversation

feed responses back into the llm
parent 77dc1a6d
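
In short: each GenerateResponse now returns the token context that produced it, and passing that slice back in the next GenerateRequest lets the model pick the conversation up where it left off. A minimal, self-contained sketch of that loop (the structs mirror the fields added to api/types.go below; send is a hypothetical stand-in for the real HTTP transport):

package main

import "fmt"

// Sketch of the feedback loop this commit adds. GenerateRequest and
// GenerateResponse mirror the fields added in api/types.go; send is a
// stub for the real transport (an HTTP POST to the ollama server).
type GenerateRequest struct {
	Model   string
	Prompt  string
	Context []int
}

type GenerateResponse struct {
	Response string
	Done     bool
	Context  []int
}

func send(req GenerateRequest) GenerateResponse {
	// Pretend the model consumed the prompt and emitted one more token.
	return GenerateResponse{Done: true, Context: append(req.Context, len(req.Prompt))}
}

func main() {
	var ctx []int // empty on the first turn
	for _, prompt := range []string{"hi there", "what did I just say?"} {
		resp := send(GenerateRequest{Model: "orca", Prompt: prompt, Context: ctx})
		ctx = resp.Context // feed the returned tokens back into the next request
	}
	fmt.Println("accumulated context:", ctx)
}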
@@ -20,6 +20,7 @@ type PullProgress struct {
 type GenerateRequest struct {
 	Model   string `json:"model"`
 	Prompt  string `json:"prompt"`
+	Context []int  `json:"context,omitempty"`

 	Options `json:"options"`
 }
@@ -30,6 +31,7 @@ type GenerateResponse struct {
 	Response string `json:"response,omitempty"`
 	Done     bool   `json:"done"`
+	Context  []int  `json:"context,omitempty"`

 	TotalDuration   time.Duration `json:"total_duration,omitempty"`
 	PromptEvalCount int           `json:"prompt_eval_count,omitempty"`
@@ -104,7 +106,7 @@ func DefaultOptions() Options {
 		UseNUMA: false,

-		NumCtx:   512,
+		NumCtx:   2048,
 		NumBatch: 512,
 		NumGPU:   1,
 		LowVRAM:  false,
...
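
The omitempty tag on the new field means a first-turn request or response simply omits the key on the wire; only follow-up turns carry token IDs. A quick standalone check (the struct mirrors api.GenerateRequest locally rather than importing it):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Local mirror of api.GenerateRequest so the sketch runs standalone.
	type generateRequest struct {
		Model   string `json:"model"`
		Prompt  string `json:"prompt"`
		Context []int  `json:"context,omitempty"`
	}

	first, _ := json.Marshal(generateRequest{Model: "orca", Prompt: "hi"})
	fmt.Println(string(first)) // {"model":"orca","prompt":"hi"}

	next, _ := json.Marshal(generateRequest{Model: "orca", Prompt: "and?", Context: []int{1, 29871}})
	fmt.Println(string(next)) // {"model":"orca","prompt":"and?","context":[1,29871]}
}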
@@ -85,6 +85,8 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }

+var generateContextKey struct{}
+
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
 		client := api.NewClient()
@@ -110,7 +112,12 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		var latest api.GenerateResponse

-		request := api.GenerateRequest{Model: model, Prompt: prompt}
+		generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+		if !ok {
+			generateContext = []int{}
+		}
+
+		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
 		fn := func(resp api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 				spinner.Finish()
@@ -119,6 +126,8 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 			latest = resp
 			fmt.Print(resp.Response)

+			cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
+
 			return nil
 		}
...
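
The CLI threads the returned context between prompts through the cobra command's context.Context: an unexported struct{} key, cmd.SetContext after each response, and a checked type assertion on the next read (this is also why the entrypoint further down switches to ExecuteContext). A standalone sketch of the same pattern, assuming a cobra version that provides Command.SetContext; the key name and token values here are illustrative, not the real ones:

package main

import (
	"context"
	"fmt"

	"github.com/spf13/cobra"
)

var demoContextKey struct{} // unexported key, same trick as generateContextKey

func main() {
	root := &cobra.Command{
		Use: "demo",
		RunE: func(cmd *cobra.Command, _ []string) error {
			for turn, prompt := range []string{"first prompt", "second prompt"} {
				// Checked assertion: on the first turn the value is absent.
				prev, _ := cmd.Context().Value(demoContextKey).([]int)
				fmt.Println(prompt, "sees context:", prev)

				// Pretend the model returned tokens; stash them for the next turn.
				cmd.SetContext(context.WithValue(cmd.Context(), demoContextKey, []int{turn, turn + 1}))
			}
			return nil
		},
	}

	// ExecuteContext seeds cmd.Context() with an explicit root context.
	_ = root.ExecuteContext(context.Background())
}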
@@ -149,9 +149,14 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }

-func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
-	if tokens := llm.tokenize(prompt); tokens != nil {
-		return llm.generate(tokens, fn)
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	if input := llm.tokenize(prompt); input != nil {
+		embd := make([]C.llama_token, len(ctx))
+		for i := range ctx {
+			embd[i] = C.llama_token(ctx[i])
+		}
+
+		return llm.generate(append(embd, input...), fn)
 	}

 	return errors.New("llama: tokenize")
@@ -194,6 +199,11 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	output := deque[C.llama_token]{capacity: llm.NumCtx}

+	context := deque[int]{capacity: llm.NumCtx / 2}
+	for _, in := range input {
+		context.PushLeft(int(in))
+	}
+
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
 		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
@@ -212,6 +222,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 		})

 		output.PushLeft(token)
+		context.PushLeft(int(token))

 		input = []C.llama_token{token}
 	}
@@ -228,6 +239,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:               true,
+		Context:            context.Data(),
 		PromptEvalCount:    int(timings.n_p_eval),
 		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
 		EvalCount:          int(timings.n_eval),
...
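
The deque type is defined elsewhere in this repo, so its exact semantics aren't visible in this diff; the usage above implies a bounded buffer that retains at most `capacity` of the most recently pushed tokens, with the context deque capped at NumCtx/2, presumably so the returned context leaves room in the window for the next prompt and reply. A hypothetical reconstruction of just that behavior (the ordering Data() returns is an assumption):

package main

import "fmt"

// Hypothetical bounded deque matching how the diff uses PushLeft/Data:
// once capacity is reached the oldest element is evicted, so Data()
// yields (assumed oldest-to-newest) the most recent `capacity` items.
type deque[T any] struct {
	capacity int
	data     []T
}

func (d *deque[T]) PushLeft(v T) {
	d.data = append(d.data, v)
	if d.capacity > 0 && len(d.data) > d.capacity {
		d.data = d.data[1:] // drop the oldest token
	}
}

func (d *deque[T]) Data() []T { return d.data }

func main() {
	d := deque[int]{capacity: 4}
	for token := 1; token <= 6; token++ {
		d.PushLeft(token)
	}
	fmt.Println(d.Data()) // [3 4 5 6]: only the newest four tokens survive
}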
 package main

 import (
+	"context"
+
 	"github.com/jmorganca/ollama/cmd"
 )

 func main() {
-	cmd.NewCLI().Execute()
+	cmd.NewCLI().ExecuteContext(context.Background())
 }
@@ -94,7 +94,7 @@ func generate(c *gin.Context) {
 		ch <- r
 	}

-	if err := llm.Predict(req.Prompt, fn); err != nil {
+	if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
...
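
Put together, the server now accepts the context it previously emitted. A client-side sketch of the round trip, assuming the conventional POST /api/generate endpoint on the default localhost:11434 address streaming one JSON object per line (the endpoint and port are not part of this diff):

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// ask sends one prompt plus any prior context and returns the context
// carried by the final ("done") streamed message.
func ask(prompt string, ctx []int) []int {
	body, _ := json.Marshal(map[string]any{
		"model":   "orca",
		"prompt":  prompt,
		"context": ctx,
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var chunk struct {
			Response string `json:"response"`
			Done     bool   `json:"done"`
			Context  []int  `json:"context"`
		}
		if err := json.Unmarshal(scanner.Bytes(), &chunk); err != nil {
			panic(err)
		}
		fmt.Print(chunk.Response)
		if chunk.Done {
			return chunk.Context // the final message carries the full context
		}
	}
	return ctx
}

func main() {
	ctx := ask("Name a color.", nil)
	ask("Name another one.", ctx) // the model sees the first exchange
}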
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request.
+{{- end }}
 ### Instruction:
 {{ .Prompt }}
...

+{{- if not .Context }}
 A helpful assistant who helps the user with any questions asked.
+{{- end }}
 User: {{ .Prompt }}
 Assistant:

+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request. Be concise. Once the request is completed, include no other text.
+{{- end }}
 ### Instruction:
 {{ .Prompt }}
 ### Response:

+{{- if not .Context }}
 ### System:
 You are an AI assistant that follows instruction extremely well. Help as much as you can.
+{{- end }}
 ### User:
 {{ .Prompt }}
...

+{{ if not .Context }}
 A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+{{- end }}
 USER: {{ .Prompt }}
 ASSISTANT:

+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request
+{{- end }}
 ### Instruction: {{ .Prompt }}
...
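
The template guards above all follow the same shape: render the system preamble only when .Context is empty, i.e. only on the first turn of a conversation. A quick demonstration with Go's text/template, using one of the templates from this commit verbatim:

package main

import (
	"fmt"
	"os"
	"text/template"
)

// One of the prompt templates from this commit.
const prompt = `{{- if not .Context }}
A helpful assistant who helps the user with any questions asked.
{{- end }}
User: {{ .Prompt }}
Assistant:`

func main() {
	tmpl := template.Must(template.New("prompt").Parse(prompt))

	// First turn: no context yet, so the preamble is rendered.
	_ = tmpl.Execute(os.Stdout, map[string]any{"Prompt": "hi"})
	fmt.Println("\n---")

	// Follow-up turn: context is present, so the preamble is skipped.
	_ = tmpl.Execute(os.Stdout, map[string]any{"Prompt": "and?", "Context": []int{1, 2}})
	fmt.Println()
}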