OpenDAS / ollama
Commit 8450bf66: trim images
Authored Jan 31, 2024 by Michael Yang
Parent: b4e11be8
Showing 3 changed files with 41 additions and 20 deletions.
llm/llama.go       +1  -1
server/images.go   +18 -9
server/routes.go   +22 -10
llm/llama.go

@@ -62,7 +62,7 @@ const maxRetries = 3
 type PredictOpts struct {
     Prompt  string
     Format  string
-    Images  []api.ImageData
+    Images  map[int]api.ImageData
     Options api.Options
 }
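This one-line type change is the core of the commit: PredictOpts now carries images keyed by an integer rank instead of by position in a slice. The sketch below is illustrative only (it is not part of the commit, and the local ImageData type merely stands in for api.ImageData, assumed here to be raw image bytes); it shows why a rank-keyed map stays consistent with the " [img-N]" markers in the prompt even after earlier prompts, and their images, have been trimmed away.

package main

import "fmt"

// ImageData stands in for api.ImageData (assumed to be a byte slice).
type ImageData []byte

func main() {
    // Three images were attached over the conversation, but the prompt that
    // referenced image 0 no longer fits the context window and was dropped.
    // The surviving prompt still says "[img-1]" and "[img-2]", and the map
    // keys still line up with those markers.
    images := map[int]ImageData{
        1: []byte("fake-png-1"),
        2: []byte("fake-png-2"),
    }
    prompt := "Compare [img-1] with [img-2]."
    fmt.Printf("%s (%d of 3 original images kept)\n", prompt, len(images))
}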
server/images.go

@@ -58,11 +58,17 @@ type Message struct {
     Content string `json:"content"`
 }
 
+type ImageData struct {
+    Rank int
+    api.ImageData
+}
+
 type PromptVars struct {
     System   string
     Prompt   string
     Response string
     First    bool
+    Images   []ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -147,15 +153,13 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 }
 
 type ChatHistory struct {
     Prompts       []PromptVars
-    CurrentImages []api.ImageData
     LastSystem    string
 }
 
 // ChatPrompts returns a list of formatted chat prompts from a list of messages
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     // build the prompt from the list of messages
-    var currentImages []api.ImageData
     lastSystem := m.System
     currentVars := PromptVars{
         First: true,
@@ -163,6 +167,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     }
 
     prompts := []PromptVars{}
+    var images []ImageData
 
     for _, msg := range msgs {
         switch strings.ToLower(msg.Role) {
@@ -182,10 +187,15 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
             currentVars.Prompt = msg.Content
             for i := range msg.Images {
-                currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(currentImages)+i)
+                currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
+                currentVars.Images = append(currentVars.Images, ImageData{
+                    Rank:      len(images) + i,
+                    ImageData: msg.Images[i],
+                })
             }
 
-            currentImages = append(currentImages, msg.Images...)
+            images = append(images, currentVars.Images...)
         case "assistant":
             currentVars.Response = msg.Content
             prompts = append(prompts, currentVars)
@@ -201,9 +211,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     }
 
     return &ChatHistory{
         Prompts:       prompts,
-        CurrentImages: currentImages,
         LastSystem:    lastSystem,
     }, nil
 }
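Read together, these hunks have ChatPrompts assign each image a conversation-wide rank as it walks the messages, embed that rank in the prompt text as a " [img-N]" marker, and carry the image alongside the prompt it belongs to. The standalone sketch below is a rough imitation of that numbering, not code from the repository; the local types only mirror the ImageData and PromptVars shapes added above, with APIImageData standing in for api.ImageData.

package main

import "fmt"

// APIImageData stands in for api.ImageData (assumed to be raw image bytes).
type APIImageData []byte

// ImageData mirrors the struct added in server/images.go: an image plus its
// conversation-wide rank.
type ImageData struct {
    Rank int
    APIImageData
}

// PromptVars is pared down to the fields this sketch needs.
type PromptVars struct {
    Prompt string
    Images []ImageData
}

func main() {
    // Two user messages: the first attaches one image, the second attaches two.
    msgs := [][]APIImageData{
        {[]byte("cat.png")},
        {[]byte("dog.png"), []byte("bird.png")},
    }

    var images []ImageData // global ordering across the whole conversation
    var prompts []PromptVars

    for _, msgImages := range msgs {
        v := PromptVars{Prompt: "look at this"}
        for i, data := range msgImages {
            rank := len(images) + i
            v.Prompt += fmt.Sprintf(" [img-%d]", rank)
            v.Images = append(v.Images, ImageData{Rank: rank, APIImageData: data})
        }
        images = append(images, v.Images...)
        prompts = append(prompts, v)
    }

    for _, p := range prompts {
        fmt.Println(p.Prompt)
    }
    // Output:
    // look at this [img-0]
    // look at this [img-1] [img-2]
}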
server/routes.go

@@ -312,11 +312,16 @@ func GenerateHandler(c *gin.Context) {
             ch <- resp
         }
 
+        images := make(map[int]api.ImageData)
+        for i := range req.Images {
+            images[i] = req.Images[i]
+        }
+
         // Start prediction
         predictReq := llm.PredictOpts{
             Prompt:  prompt,
             Format:  req.Format,
-            Images:  req.Images,
+            Images:  images,
             Options: opts,
         }
         if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
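In the generate endpoint nothing is ever trimmed, so this conversion is a plain copy and the map keys are simply the request indices 0..n-1. A minimal, hypothetical sketch of that shape (ImageData again standing in for api.ImageData):

package main

import "fmt"

// ImageData stands in for api.ImageData (assumed to be a byte slice).
type ImageData []byte

func main() {
    reqImages := []ImageData{[]byte("first.png"), []byte("second.png")}

    // Same shape as the loop added to GenerateHandler: dense keys 0..n-1.
    images := make(map[int]ImageData)
    for i := range reqImages {
        images[i] = reqImages[i]
    }

    fmt.Println(len(images)) // 2, under keys 0 and 1
}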
@@ -1143,7 +1148,8 @@ func ChatHandler(c *gin.Context) {
         c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
         return
     }
 
-    prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+    prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
     if err != nil {
         c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
         return
@@ -1186,7 +1192,7 @@ func ChatHandler(c *gin.Context) {
         predictReq := llm.PredictOpts{
             Prompt:  prompt,
             Format:  req.Format,
-            Images:  chat.CurrentImages,
+            Images:  images,
             Options: opts,
         }
         if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1233,25 +1239,27 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
     if len(chat.Prompts) == 0 {
-        return "", nil
+        return "", nil, nil
     }
 
     var promptsToAdd []promptInfo
     var totalTokenLength int
     var systemPromptIncluded bool
 
+    images := make(map[int]api.ImageData)
     // reverse iterate through the prompts to build the prompt string in a way that fits the max context length
     for i := len(chat.Prompts) - 1; i >= 0; i-- {
         promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
 
         encodedTokens, err := loaded.runner.Encode(ctx, promptText)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
 
         if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
@@ -1261,6 +1269,10 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
         totalTokenLength += len(encodedTokens)
         systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
         promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+
+        for _, image := range chat.Prompts[i].Images {
+            images[image.Rank] = image.ImageData
+        }
     }
 
     // ensure the system prompt is included, if not already
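This hunk is where the commit earns its name: images are collected only from prompts that survive the context-length check, and each surviving image is filed under its original rank, so the returned map may be sparse but its keys still match the " [img-N]" markers left in the prompt text. The following self-contained sketch (illustrative values, not repository code) shows that filtering:

package main

import "fmt"

// promptWithImages is a pared-down stand-in for PromptVars together with the
// ranks of the images it references.
type promptWithImages struct {
    text  string
    ranks []int
}

func main() {
    // All images ever attached, keyed by rank, standing in for api.ImageData.
    all := map[int][]byte{
        0: []byte("a.png"),
        1: []byte("b.png"),
        2: []byte("c.png"),
    }
    prompts := []promptWithImages{
        {text: "old question [img-0]", ranks: []int{0}},
        {text: "new question [img-1] [img-2]", ranks: []int{1, 2}},
    }

    // Pretend the context window only fits the most recent prompt, as the
    // reverse loop in trimmedPrompt might decide.
    surviving := prompts[1:]

    images := make(map[int][]byte)
    for _, p := range surviving {
        for _, r := range p.ranks {
            images[r] = all[r]
        }
    }

    _, kept := images[0]
    fmt.Println(len(images), kept) // 2 false: image 0 was dropped with its prompt
}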
@@ -1268,7 +1280,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
         var err error
         promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
     }
@@ -1279,11 +1291,11 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
     for i, prompt := range promptsToAdd {
         promptText, err := promptString(model, prompt.vars, i == 0)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
         result = promptText + result
     }
 
-    return result, nil
+    return result, images, nil
 }
 
 // promptString applies the model template to the prompt