Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
0632dff3
"...text-generation-inference.git" did not exist on "f063ebde103a78ee1fb6eafb66cb462526c0ce44"
Unverified
Commit
0632dff3
authored
Jan 30, 2024
by
Bruce MacDonald
Committed by
GitHub
Jan 30, 2024
Browse files
trim chat prompt based on llm context size (#1963)
parent
509e2dec
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
440 additions
and
57 deletions
+440
-57
server/images.go
server/images.go
+24
-27
server/images_test.go
server/images_test.go
+56
-28
server/routes.go
server/routes.go
+105
-2
server/routes_test.go
server/routes_test.go
+255
-0
No files found.
server/images.go
View file @
0632dff3
...
...
@@ -146,62 +146,59 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
return
Prompt
(
post
,
p
)
}
func
(
m
*
Model
)
ChatPrompt
(
msgs
[]
api
.
Message
)
(
string
,
[]
api
.
ImageData
,
error
)
{
// ChatHistory is the structured form of a chat conversation produced by
// ChatPrompts: the per-turn template variables plus the pieces of state that
// must survive when older turns are dropped.
type ChatHistory struct {
	// Prompts holds one PromptVars entry per conversation turn, oldest first.
	Prompts []PromptVars
	// CurrentImages is the image list taken from the most recently seen user message.
	CurrentImages []api.ImageData
	// LastSystem is the content of the most recently seen system message, or "" if none was sent.
	LastSystem string
}
// ChatPrompts returns a list of formatted chat prompts from a list of messages
func
(
m
*
Model
)
ChatPrompts
(
msgs
[]
api
.
Message
)
(
*
ChatHistory
,
error
)
{
// build the prompt from the list of messages
var
prompt
strings
.
Builder
var
currentImages
[]
api
.
ImageData
var
lastSystem
string
currentVars
:=
PromptVars
{
First
:
true
,
System
:
m
.
System
,
}
writePrompt
:=
func
()
error
{
p
,
err
:=
Prompt
(
m
.
Template
,
currentVars
)
if
err
!=
nil
{
return
err
}
prompt
.
WriteString
(
p
)
currentVars
=
PromptVars
{}
return
nil
}
prompts
:=
[]
PromptVars
{}
for
_
,
msg
:=
range
msgs
{
switch
strings
.
ToLower
(
msg
.
Role
)
{
case
"system"
:
if
currentVars
.
System
!=
""
{
if
err
:=
writePrompt
();
err
!=
nil
{
return
""
,
nil
,
err
}
prompts
=
append
(
prompts
,
currentVars
)
currentVars
=
PromptVars
{}
}
currentVars
.
System
=
msg
.
Content
lastSystem
=
msg
.
Content
case
"user"
:
if
currentVars
.
Prompt
!=
""
{
if
err
:=
writePrompt
();
err
!=
nil
{
return
""
,
nil
,
err
}
prompts
=
append
(
prompts
,
currentVars
)
currentVars
=
PromptVars
{}
}
currentVars
.
Prompt
=
msg
.
Content
currentImages
=
msg
.
Images
case
"assistant"
:
currentVars
.
Response
=
msg
.
Content
if
err
:=
writePrompt
();
err
!=
nil
{
return
""
,
nil
,
err
}
prompts
=
append
(
prompts
,
currentVars
)
currentVars
=
PromptVars
{}
default
:
return
""
,
nil
,
fmt
.
Errorf
(
"invalid role: %s, role must be one of [system, user, assistant]"
,
msg
.
Role
)
return
nil
,
fmt
.
Errorf
(
"invalid role: %s, role must be one of [system, user, assistant]"
,
msg
.
Role
)
}
}
// Append the last set of vars if they are non-empty
if
currentVars
.
Prompt
!=
""
||
currentVars
.
System
!=
""
{
p
,
err
:=
m
.
PreResponsePrompt
(
currentVars
)
if
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"pre-response template: %w"
,
err
)
}
prompt
.
WriteString
(
p
)
prompts
=
append
(
prompts
,
currentVars
)
}
return
prompt
.
String
(),
currentImages
,
nil
return
&
ChatHistory
{
Prompts
:
prompts
,
CurrentImages
:
currentImages
,
LastSystem
:
lastSystem
,
},
nil
}
type
ManifestV2
struct
{
...
...
server/images_test.go
View file @
0632dff3
package
server
import
(
"bytes"
"strings"
"testing"
...
...
@@ -233,12 +234,32 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
}
}
func
chatHistoryEqual
(
a
,
b
ChatHistory
)
bool
{
if
len
(
a
.
Prompts
)
!=
len
(
b
.
Prompts
)
{
return
false
}
if
len
(
a
.
CurrentImages
)
!=
len
(
b
.
CurrentImages
)
{
return
false
}
for
i
,
v
:=
range
a
.
Prompts
{
if
v
!=
b
.
Prompts
[
i
]
{
return
false
}
}
for
i
,
v
:=
range
a
.
CurrentImages
{
if
!
bytes
.
Equal
(
v
,
b
.
CurrentImages
[
i
])
{
return
false
}
}
return
a
.
LastSystem
==
b
.
LastSystem
}
func
TestChat
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
template
string
msgs
[]
api
.
Message
want
string
want
ChatHistory
wantErr
string
}{
{
...
...
@@ -254,30 +275,16 @@ func TestChat(t *testing.T) {
Content
:
"What are the potion ingredients?"
,
},
},
want
:
"[INST] You are a Wizard. What are the potion ingredients? [/INST]"
,
},
{
name
:
"First Message"
,
template
:
"[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]"
,
msgs
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a Wizard."
,
},
{
Role
:
"user"
,
Content
:
"What are the potion ingredients?"
,
},
{
Role
:
"assistant"
,
Content
:
"eye of newt"
,
},
{
Role
:
"user"
,
Content
:
"Anything else?"
,
want
:
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
System
:
"You are a Wizard."
,
Prompt
:
"What are the potion ingredients?"
,
First
:
true
,
},
},
LastSystem
:
"You are a Wizard."
,
},
want
:
"[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]"
,
},
{
name
:
"Message History"
,
...
...
@@ -300,7 +307,20 @@ func TestChat(t *testing.T) {
Content
:
"Anything else?"
,
},
},
want
:
"[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]"
,
want
:
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
System
:
"You are a Wizard."
,
Prompt
:
"What are the potion ingredients?"
,
Response
:
"sugar"
,
First
:
true
,
},
{
Prompt
:
"Anything else?"
,
},
},
LastSystem
:
"You are a Wizard."
,
},
},
{
name
:
"Assistant Only"
,
...
...
@@ -311,7 +331,14 @@ func TestChat(t *testing.T) {
Content
:
"everything nice"
,
},
},
want
:
"[INST] [/INST]everything nice"
,
want
:
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
Response
:
"everything nice"
,
First
:
true
,
},
},
},
},
{
name
:
"Invalid Role"
,
...
...
@@ -330,7 +357,7 @@ func TestChat(t *testing.T) {
Template
:
tt
.
template
,
}
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
,
_
,
err
:=
m
.
ChatPrompt
(
tt
.
msgs
)
got
,
err
:=
m
.
ChatPrompt
s
(
tt
.
msgs
)
if
tt
.
wantErr
!=
""
{
if
err
==
nil
{
t
.
Errorf
(
"ChatPrompt() expected error, got nil"
)
...
...
@@ -338,9 +365,10 @@ func TestChat(t *testing.T) {
if
!
strings
.
Contains
(
err
.
Error
(),
tt
.
wantErr
)
{
t
.
Errorf
(
"ChatPrompt() error = %v, wantErr %v"
,
err
,
tt
.
wantErr
)
}
return
}
if
got
!=
tt
.
want
{
t
.
Errorf
(
"ChatPrompt() got = %v, want %v"
,
got
,
tt
.
want
)
if
!
chatHistoryEqual
(
*
got
,
tt
.
want
)
{
t
.
Errorf
(
"ChatPrompt() got = %
#
v, want %
#
v"
,
got
,
tt
.
want
)
}
})
}
...
...
server/routes.go
View file @
0632dff3
...
...
@@ -1121,11 +1121,16 @@ func ChatHandler(c *gin.Context) {
checkpointLoaded
:=
time
.
Now
()
prompt
,
images
,
err
:=
model
.
ChatPrompt
(
req
.
Messages
)
chat
,
err
:=
model
.
ChatPrompt
s
(
req
.
Messages
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
prompt
,
err
:=
trimmedPrompt
(
c
.
Request
.
Context
(),
chat
,
model
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
slog
.
Debug
(
fmt
.
Sprintf
(
"prompt: %s"
,
prompt
))
...
...
@@ -1164,7 +1169,7 @@ func ChatHandler(c *gin.Context) {
predictReq
:=
llm
.
PredictOpts
{
Prompt
:
prompt
,
Format
:
req
.
Format
,
Images
:
i
mages
,
Images
:
chat
.
CurrentI
mages
,
Options
:
opts
,
}
if
err
:=
loaded
.
runner
.
Predict
(
c
.
Request
.
Context
(),
predictReq
,
fn
);
err
!=
nil
{
...
...
@@ -1202,3 +1207,101 @@ func ChatHandler(c *gin.Context) {
streamResponse
(
c
,
ch
)
}
// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
type promptInfo struct {
	// vars are the template variables for a single conversation turn.
	vars PromptVars
	// tokenLen is the number of tokens the runner produced when encoding the rendered turn.
	tokenLen int
}
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
// while preserving the most recent system message.
func
trimmedPrompt
(
ctx
context
.
Context
,
chat
*
ChatHistory
,
model
*
Model
)
(
string
,
error
)
{
if
len
(
chat
.
Prompts
)
==
0
{
return
""
,
nil
}
var
promptsToAdd
[]
promptInfo
var
totalTokenLength
int
var
systemPromptIncluded
bool
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
for
i
:=
len
(
chat
.
Prompts
)
-
1
;
i
>=
0
;
i
--
{
promptText
,
err
:=
promptString
(
model
,
chat
.
Prompts
[
i
],
i
==
len
(
chat
.
Prompts
)
-
1
)
if
err
!=
nil
{
return
""
,
err
}
encodedTokens
,
err
:=
loaded
.
runner
.
Encode
(
ctx
,
promptText
)
if
err
!=
nil
{
return
""
,
err
}
if
totalTokenLength
+
len
(
encodedTokens
)
>
loaded
.
NumCtx
&&
i
!=
len
(
chat
.
Prompts
)
-
1
{
break
// reached max context length, stop adding more prompts
}
totalTokenLength
+=
len
(
encodedTokens
)
systemPromptIncluded
=
systemPromptIncluded
||
chat
.
Prompts
[
i
]
.
System
!=
""
promptsToAdd
=
append
(
promptsToAdd
,
promptInfo
{
vars
:
chat
.
Prompts
[
i
],
tokenLen
:
len
(
encodedTokens
)})
}
// ensure the system prompt is included, if not already
if
chat
.
LastSystem
!=
""
&&
!
systemPromptIncluded
{
var
err
error
promptsToAdd
,
err
=
includeSystemPrompt
(
ctx
,
chat
.
LastSystem
,
totalTokenLength
,
promptsToAdd
)
if
err
!=
nil
{
return
""
,
err
}
}
promptsToAdd
[
len
(
promptsToAdd
)
-
1
]
.
vars
.
First
=
true
// construct the final prompt string from the prompts which fit within the context window
var
result
string
for
i
,
prompt
:=
range
promptsToAdd
{
promptText
,
err
:=
promptString
(
model
,
prompt
.
vars
,
i
==
0
)
if
err
!=
nil
{
return
""
,
err
}
result
=
promptText
+
result
}
return
result
,
nil
}
// promptString applies the model template to the prompt
func
promptString
(
model
*
Model
,
vars
PromptVars
,
isMostRecent
bool
)
(
string
,
error
)
{
if
isMostRecent
{
p
,
err
:=
model
.
PreResponsePrompt
(
vars
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"pre-response template: %w"
,
err
)
}
return
p
,
nil
}
p
,
err
:=
Prompt
(
model
.
Template
,
vars
)
if
err
!=
nil
{
return
""
,
err
}
return
p
,
nil
}
// includeSystemPrompt adjusts the prompts to include the system prompt.
func
includeSystemPrompt
(
ctx
context
.
Context
,
systemPrompt
string
,
totalTokenLength
int
,
promptsToAdd
[]
promptInfo
)
([]
promptInfo
,
error
)
{
systemTokens
,
err
:=
loaded
.
runner
.
Encode
(
ctx
,
systemPrompt
)
if
err
!=
nil
{
return
nil
,
err
}
for
i
:=
len
(
promptsToAdd
)
-
1
;
i
>=
0
;
i
--
{
if
totalTokenLength
+
len
(
systemTokens
)
<=
loaded
.
NumCtx
{
promptsToAdd
[
i
]
.
vars
.
System
=
systemPrompt
return
promptsToAdd
[
:
i
+
1
],
nil
}
totalTokenLength
-=
promptsToAdd
[
i
]
.
tokenLen
}
// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
recent
:=
promptsToAdd
[
len
(
promptsToAdd
)
-
1
]
recent
.
vars
.
System
=
systemPrompt
return
[]
promptInfo
{
recent
},
nil
}
server/routes_test.go
View file @
0632dff3
...
...
@@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version"
)
...
...
@@ -239,3 +240,257 @@ func Test_Routes(t *testing.T) {
}
}
func
Test_ChatPrompt
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
template
string
chat
*
ChatHistory
numCtx
int
runner
MockLLM
want
string
wantErr
string
}{
{
name
:
"Single Message"
,
template
:
"[INST] {{ .System }} {{ .Prompt }} [/INST]"
,
chat
:
&
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
System
:
"You are a Wizard."
,
Prompt
:
"What are the potion ingredients?"
,
First
:
true
,
},
},
LastSystem
:
"You are a Wizard."
,
},
numCtx
:
1
,
runner
:
MockLLM
{
encoding
:
[]
int
{
1
},
// fit the ctxLen
},
want
:
"[INST] You are a Wizard. What are the potion ingredients? [/INST]"
,
},
{
name
:
"First Message"
,
template
:
"[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]"
,
chat
:
&
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
System
:
"You are a Wizard."
,
Prompt
:
"What are the potion ingredients?"
,
Response
:
"eye of newt"
,
First
:
true
,
},
{
Prompt
:
"Anything else?"
,
},
},
LastSystem
:
"You are a Wizard."
,
},
numCtx
:
2
,
runner
:
MockLLM
{
encoding
:
[]
int
{
1
},
// fit the ctxLen
},
want
:
"[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]"
,
},
{
name
:
"Message History"
,
template
:
"[INST] {{ .System }} {{ .Prompt }} [/INST]"
,
chat
:
&
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
System
:
"You are a Wizard."
,
Prompt
:
"What are the potion ingredients?"
,
Response
:
"sugar"
,
First
:
true
,
},
{
Prompt
:
"Anything else?"
,
},
},
LastSystem
:
"You are a Wizard."
,
},
numCtx
:
4
,
runner
:
MockLLM
{
encoding
:
[]
int
{
1
},
// fit the ctxLen, 1 for each message
},
want
:
"[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]"
,
},
{
name
:
"Assistant Only"
,
template
:
"[INST] {{ .System }} {{ .Prompt }} [/INST]"
,
chat
:
&
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
Response
:
"everything nice"
,
First
:
true
,
},
},
},
numCtx
:
1
,
runner
:
MockLLM
{
encoding
:
[]
int
{
1
},
},
want
:
"[INST] [/INST]everything nice"
,
},
{
name
:
"Message History Truncated, No System"
,
template
:
"[INST] {{ .System }} {{ .Prompt }} [/INST]"
,
chat
:
&
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
Prompt
:
"What are the potion ingredients?"
,
Response
:
"sugar"
,
First
:
true
,
},
{
Prompt
:
"Anything else?"
,
Response
:
"spice"
,
},
{
Prompt
:
"... and?"
,
},
},
},
numCtx
:
2
,
// only 1 message from history and most recent message
runner
:
MockLLM
{
encoding
:
[]
int
{
1
},
},
want
:
"[INST] Anything else? [/INST]spice[INST] ... and? [/INST]"
,
},
{
name
:
"System is Preserved when Truncated"
,
template
:
"[INST] {{ .System }} {{ .Prompt }} [/INST]"
,
chat
:
&
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
Prompt
:
"What are the magic words?"
,
Response
:
"abracadabra"
,
},
{
Prompt
:
"What is the spell for invisibility?"
,
},
},
LastSystem
:
"You are a wizard."
,
},
numCtx
:
2
,
runner
:
MockLLM
{
encoding
:
[]
int
{
1
},
},
want
:
"[INST] You are a wizard. What is the spell for invisibility? [/INST]"
,
},
{
name
:
"System is Preserved when Length Exceeded"
,
template
:
"[INST] {{ .System }} {{ .Prompt }} [/INST]"
,
chat
:
&
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
Prompt
:
"What are the magic words?"
,
Response
:
"abracadabra"
,
},
{
Prompt
:
"What is the spell for invisibility?"
,
},
},
LastSystem
:
"You are a wizard."
,
},
numCtx
:
1
,
runner
:
MockLLM
{
encoding
:
[]
int
{
1
},
},
want
:
"[INST] You are a wizard. What is the spell for invisibility? [/INST]"
,
},
{
name
:
"First is Preserved when Truncated"
,
template
:
"[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"
,
chat
:
&
ChatHistory
{
Prompts
:
[]
PromptVars
{
// first message omitted for test
{
Prompt
:
"Do you have a magic hat?"
,
Response
:
"Of course."
,
},
{
Prompt
:
"What is the spell for invisibility?"
,
},
},
LastSystem
:
"You are a wizard."
,
},
numCtx
:
3
,
// two most recent messages and room for system message
runner
:
MockLLM
{
encoding
:
[]
int
{
1
},
},
want
:
"[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]"
,
},
{
name
:
"Most recent message is returned when longer than ctxLen"
,
template
:
"[INST] {{ .Prompt }} [/INST]"
,
chat
:
&
ChatHistory
{
Prompts
:
[]
PromptVars
{
{
Prompt
:
"What is the spell for invisibility?"
,
First
:
true
,
},
},
},
numCtx
:
1
,
// two most recent messages
runner
:
MockLLM
{
encoding
:
[]
int
{
1
,
2
},
},
want
:
"[INST] What is the spell for invisibility? [/INST]"
,
},
}
for
_
,
testCase
:=
range
tests
{
tt
:=
testCase
m
:=
&
Model
{
Template
:
tt
.
template
,
}
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
loaded
.
runner
=
&
tt
.
runner
loaded
.
Options
=
&
api
.
Options
{
Runner
:
api
.
Runner
{
NumCtx
:
tt
.
numCtx
,
},
}
got
,
err
:=
trimmedPrompt
(
context
.
Background
(),
tt
.
chat
,
m
)
if
tt
.
wantErr
!=
""
{
if
err
==
nil
{
t
.
Errorf
(
"ChatPrompt() expected error, got nil"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
tt
.
wantErr
)
{
t
.
Errorf
(
"ChatPrompt() error = %v, wantErr %v"
,
err
,
tt
.
wantErr
)
}
}
if
got
!=
tt
.
want
{
t
.
Errorf
(
"ChatPrompt() got = %v, want %v"
,
got
,
tt
.
want
)
}
})
}
}
type
MockLLM
struct
{
encoding
[]
int
}
func
(
llm
*
MockLLM
)
Predict
(
ctx
context
.
Context
,
pred
llm
.
PredictOpts
,
fn
func
(
llm
.
PredictResult
))
error
{
return
nil
}
func
(
llm
*
MockLLM
)
Encode
(
ctx
context
.
Context
,
prompt
string
)
([]
int
,
error
)
{
return
llm
.
encoding
,
nil
}
func
(
llm
*
MockLLM
)
Decode
(
ctx
context
.
Context
,
tokens
[]
int
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
llm
*
MockLLM
)
Embedding
(
ctx
context
.
Context
,
input
string
)
([]
float64
,
error
)
{
return
[]
float64
{},
nil
}
func
(
llm
*
MockLLM
)
Close
()
{
// do nothing
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment