OpenDAS / ollama · Commits

Commit bfbf2f7c (Unverified)
Authored Feb 01, 2024 by Michael Yang; committed by GitHub, Feb 01, 2024

Merge pull request #2296 from ollama/mxyng/img-tags

append image tags to user content
Parents: fe3cbd01, f3761405

Showing 6 changed files with 89 additions and 36 deletions:
	llm/dyn_ext_server.go  +3  -6
	llm/llama.go           +1  -1
	server/images.go       +17 -8
	server/images_test.go  +26 -7
	server/routes.go       +40 -13
	server/routes_test.go  +2  -1
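Across these six files, the change replaces ChatHistory.CurrentImages (which carried only the latest message's images, renumbered from zero on every request) with inline image references: each user message's images get conversation-wide IDs, a matching " [img-N]" tag is appended to that message's content, and the selected images travel with the prompt into the runner's image_data field. Below is a minimal, self-contained sketch of the resulting prompt/payload pairing; the ImageData shape and the pngBytes value are illustrative assumptions, not part of the diff.

	package main

	import "fmt"

	// ImageData mirrors llm.ImageData as used in this diff: raw image bytes
	// plus the numeric ID that an [img-N] tag in the prompt refers to.
	type ImageData struct {
		Data []byte `json:"data"`
		ID   int    `json:"id"`
	}

	func main() {
		pngBytes := []byte{0x89, 'P', 'N', 'G'} // hypothetical stand-in for real image bytes
		request := map[string]any{
			// the " [img-0]" tag appended to the user content points at ID 0 below
			"prompt":     "What is in this image? [img-0]",
			"image_data": []ImageData{{ID: 0, Data: pngBytes}},
		}
		fmt.Println(request["prompt"])
	}

Tying each tag to an explicit ID is what later lets trimmedPrompt (below) drop an individual image and strip exactly its tag, without renumbering the rest.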
llm/dyn_ext_server.go

@@ -161,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
 func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
-	var imageData []ImageData
 	if len(predict.Images) > 0 {
-		for cnt, i := range predict.Images {
-			imageData = append(imageData, ImageData{Data: i, ID: cnt})
-		}
+		slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
 	}
-
-	slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
 
 	request := map[string]any{
 		"prompt": predict.Prompt,

@@ -189,7 +186,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		"penalize_nl":  predict.Options.PenalizeNewline,
 		"seed":         predict.Options.Seed,
 		"stop":         predict.Options.Stop,
-		"image_data":   imageData,
+		"image_data":   predict.Images,
 		"cache_prompt": true,
 	}
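With IDs now assigned upstream, Predict no longer renumbers images per request; predict.Images is forwarded as-is under the image_data key. A standalone sketch of how that serializes, assuming the request map is JSON-encoded and the data/id field tags shown above ([]byte marshals as base64 in encoding/json):

	package main

	import (
		"encoding/json"
		"fmt"
	)

	// ImageData mirrors the fields used in this diff (presumed JSON tags).
	type ImageData struct {
		Data []byte `json:"data"`
		ID   int    `json:"id"`
	}

	func main() {
		payload, err := json.Marshal(map[string]any{
			"prompt":     "describe [img-3]",
			"image_data": []ImageData{{ID: 3, Data: []byte("...")}},
		})
		if err != nil {
			panic(err)
		}
		// []byte fields marshal as base64, which the downstream server decodes:
		fmt.Println(string(payload))
		// {"image_data":[{"data":"Li4u","id":3}],"prompt":"describe [img-3]"}
	}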
llm/llama.go

@@ -62,7 +62,7 @@ const maxRetries = 3
 type PredictOpts struct {
 	Prompt  string
 	Format  string
-	Images  []api.ImageData
+	Images  []ImageData
 	Options api.Options
 }
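PredictOpts now uses the llm package's own ImageData instead of api.ImageData (a raw []byte), so each image carries its ID with it. The type's definition is not shown in this diff; inferred from how it is used here, it is presumably close to:

	package llm

	// Presumed shape of ImageData, inferred from its use in this diff
	// (ID and Data fields, JSON-encoded into the runner's image_data array).
	type ImageData struct {
		Data []byte `json:"data"`
		ID   int    `json:"id"`
	}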
server/images.go

@@ -63,6 +63,7 @@ type PromptVars struct {
 	Prompt   string
 	Response string
 	First    bool
+	Images   []llm.ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.

@@ -147,15 +148,13 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 }
 
 type ChatHistory struct {
-	Prompts       []PromptVars
-	CurrentImages []api.ImageData
-	LastSystem    string
+	Prompts    []PromptVars
+	LastSystem string
 }
 
 // ChatPrompts returns a list of formatted chat prompts from a list of messages
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
-	var currentImages []api.ImageData
 	lastSystem := m.System
 	currentVars := PromptVars{
 		First: true,

@@ -163,6 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 	prompts := []PromptVars{}
+	var images []llm.ImageData
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {

@@ -179,8 +179,18 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 				prompts = append(prompts, currentVars)
 				currentVars = PromptVars{}
 			}
 			currentVars.Prompt = msg.Content
-			currentImages = msg.Images
+
+			for i := range msg.Images {
+				id := len(images) + i
+				currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
+				currentVars.Images = append(currentVars.Images, llm.ImageData{
+					ID:   id,
+					Data: msg.Images[i],
+				})
+			}
+
+			images = append(images, currentVars.Images...)
 		case "assistant":
 			currentVars.Response = msg.Content
 			prompts = append(prompts, currentVars)

@@ -196,9 +206,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	return &ChatHistory{
-		Prompts:       prompts,
-		CurrentImages: currentImages,
-		LastSystem:    lastSystem,
+		Prompts:    prompts,
+		LastSystem: lastSystem,
 	}, nil
 }
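Note that the numbering in ChatPrompts is global to the conversation: id is offset by len(images), which accumulates every earlier turn's images, so tags from different user turns never collide. A minimal standalone re-creation of just the numbering (message contents are hypothetical):

	package main

	import "fmt"

	func img(s string) []byte { return []byte(s) }

	func main() {
		msgs := []struct {
			Content string
			Images  [][]byte
		}{
			{"compare these", [][]byte{img("a"), img("b")}}, // first user turn
			{"and this one", [][]byte{img("c")}},            // later user turn
		}

		var total int // plays the role of len(images) in the diff
		for _, msg := range msgs {
			prompt := msg.Content
			for i := range msg.Images {
				prompt += fmt.Sprintf(" [img-%d]", total+i)
			}
			total += len(msg.Images)
			fmt.Println(prompt)
		}
		// Output:
		// compare these [img-0] [img-1]
		// and this one [img-2]
	}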
server/images_test.go

@@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool {
 	if len(a.Prompts) != len(b.Prompts) {
 		return false
 	}
-	if len(a.CurrentImages) != len(b.CurrentImages) {
-		return false
-	}
 	for i, v := range a.Prompts {
-		if v != b.Prompts[i] {
+		if v.First != b.Prompts[i].First {
 			return false
 		}
-	}
-	for i, v := range a.CurrentImages {
-		if !bytes.Equal(v, b.CurrentImages[i]) {
+		if v.Response != b.Prompts[i].Response {
+			return false
+		}
+		if v.Prompt != b.Prompts[i].Prompt {
+			return false
+		}
+		if v.System != b.Prompts[i].System {
+			return false
+		}
+		if len(v.Images) != len(b.Prompts[i].Images) {
 			return false
 		}
+		for j, img := range v.Images {
+			if img.ID != b.Prompts[i].Images[j].ID {
+				return false
+			}
+			if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
+				return false
+			}
+		}
 	}
 	return a.LastSystem == b.LastSystem
 }
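The comparator had to be rewritten along with the struct: once PromptVars gained an Images slice field, whole-struct equality (v != b.Prompts[i]) no longer compiles, because Go structs are comparable only when all their fields are. Hence the field-by-field checks, with bytes.Equal for the image payloads. A two-type illustration of the rule:

	package main

	import "fmt"

	type noSlices struct{ A string }
	type withSlice struct{ A []byte } // a slice field makes values non-comparable

	func main() {
		fmt.Println(noSlices{"x"} == noSlices{"x"}) // true: all fields comparable
		// This would fail to compile ("invalid operation: ... cannot be compared"):
		// _ = withSlice{[]byte("x")} == withSlice{[]byte("x")}
	}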
server/routes.go

@@ -244,6 +244,10 @@ func GenerateHandler(c *gin.Context) {
 		promptVars.System = model.System
 	}
 
+	for i := range req.Images {
+		promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
+	}
+
 	p, err := model.PreResponsePrompt(promptVars)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
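On the generate path there is no chat history, so tag IDs are simply the request's image indices, matching the llm.ImageData conversion in the next hunk, which uses the same indices as IDs. A standalone sketch of the effect on the prompt (request contents hypothetical):

	package main

	import "fmt"

	func main() {
		// Hypothetical /api/generate request with two images attached.
		images := [][]byte{[]byte("a"), []byte("b")}
		prompt := "what changed between these?"
		for i := range images {
			prompt += fmt.Sprintf(" [img-%d]", i) // IDs are just request indices here
		}
		fmt.Println(prompt) // what changed between these? [img-0] [img-1]
	}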
@@ -308,11 +312,19 @@ func GenerateHandler(c *gin.Context) {
 			ch <- resp
 		}
 
+		var images []llm.ImageData
+		for i := range req.Images {
+			images = append(images, llm.ImageData{
+				ID:   i,
+				Data: req.Images[i],
+			})
+		}
+
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  req.Images,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {

@@ -1139,7 +1151,8 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
-	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+
+	prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return

@@ -1182,7 +1195,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  chat.CurrentImages,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1229,34 +1242,47 @@ type promptInfo struct {
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	if len(chat.Prompts) == 0 {
-		return "", nil
+		return "", nil, nil
 	}
 
 	var promptsToAdd []promptInfo
 	var totalTokenLength int
 	var systemPromptIncluded bool
+	var images []llm.ImageData
 
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
-		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+		prompt := chat.Prompts[i]
+		promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
 			break // reached max context length, stop adding more prompts
 		}
 
+		for j := range prompt.Images {
+			if totalTokenLength+768 > loaded.NumCtx {
+				// this decreases the token length but overestimating is fine
+				prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
+				continue
+			}
+
+			totalTokenLength += 768
+			images = append(images, prompt.Images[j])
+		}
+
 		totalTokenLength += len(encodedTokens)
-		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
-		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+		systemPromptIncluded = systemPromptIncluded || prompt.System != ""
+		promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
 	}
 
 	// ensure the system prompt is included, if not already
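trimmedPrompt now budgets a flat 768 tokens per image against loaded.NumCtx; that constant is presumably a deliberate overestimate of the vision encoder's per-image token cost (the in-code comment says overestimating is fine). An image that does not fit is dropped and its [img-N] tag is stripped, so the prompt never references an image that was not sent. A standalone sketch of the drop-and-strip behavior with assumed numbers:

	package main

	import (
		"fmt"
		"strings"
	)

	func main() {
		const numCtx = 2048 // model context window (assumed)
		totalTokens := 1200 // tokens already spent on prompt text (assumed)
		prompt := "first [img-0] second [img-1]"

		for _, id := range []int{0, 1} {
			if totalTokens+768 > numCtx {
				// over budget: drop this image and strip its tag, as the diff does
				prompt = strings.ReplaceAll(prompt, fmt.Sprintf(" [img-%d]", id), "")
				continue
			}
			totalTokens += 768 // flat per-image charge
		}
		fmt.Println(prompt) // first [img-0] second
	}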
@@ -1264,7 +1290,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		var err error
 		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 	}

@@ -1275,11 +1301,12 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	for i, prompt := range promptsToAdd {
 		promptText, err := promptString(model, prompt.vars, i == 0)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 		result = promptText + result
 	}
-	return result, nil
+	return result, images, nil
 }
 
 // promptString applies the model template to the prompt
server/routes_test.go

@@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) {
 				NumCtx: tt.numCtx,
 			},
 		}
-		got, err := trimmedPrompt(context.Background(), tt.chat, m)
+		// TODO: add tests for trimming images
+		got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
 		if tt.wantErr != "" {
 			if err == nil {
 				t.Errorf("ChatPrompt() expected error, got nil")