Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
9bbddc37
Unverified
Commit
9bbddc37
authored
Jul 09, 2024
by
Michael Yang
Committed by
GitHub
Jul 09, 2024
Browse files
Merge pull request #5126 from ollama/mxyng/messages
update message processing
parents
e4ff7329
326363b3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
677 additions
and
709 deletions
+677
-709
llm/server.go
llm/server.go
+1
-1
server/images.go
server/images.go
+13
-4
server/prompt.go
server/prompt.go
+53
-187
server/prompt_test.go
server/prompt_test.go
+155
-166
server/routes.go
server/routes.go
+157
-339
template/template.go
template/template.go
+138
-8
template/template_test.go
template/template_test.go
+160
-4
No files found.
llm/server.go
View file @
9bbddc37
...
@@ -679,7 +679,7 @@ type CompletionRequest struct {
...
@@ -679,7 +679,7 @@ type CompletionRequest struct {
Prompt
string
Prompt
string
Format
string
Format
string
Images
[]
ImageData
Images
[]
ImageData
Options
api
.
Options
Options
*
api
.
Options
}
}
type
CompletionResponse
struct
{
type
CompletionResponse
struct
{
...
...
server/images.go
View file @
9bbddc37
...
@@ -34,6 +34,8 @@ import (
...
@@ -34,6 +34,8 @@ import (
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/version"
)
)
var
errCapabilityCompletion
=
errors
.
New
(
"completion"
)
type
Capability
string
type
Capability
string
const
CapabilityCompletion
=
Capability
(
"completion"
)
const
CapabilityCompletion
=
Capability
(
"completion"
)
...
@@ -62,7 +64,10 @@ type Model struct {
...
@@ -62,7 +64,10 @@ type Model struct {
Template
*
template
.
Template
Template
*
template
.
Template
}
}
func
(
m
*
Model
)
Has
(
caps
...
Capability
)
bool
{
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func
(
m
*
Model
)
CheckCapabilities
(
caps
...
Capability
)
error
{
var
errs
[]
error
for
_
,
cap
:=
range
caps
{
for
_
,
cap
:=
range
caps
{
switch
cap
{
switch
cap
{
case
CapabilityCompletion
:
case
CapabilityCompletion
:
...
@@ -81,15 +86,19 @@ func (m *Model) Has(caps ...Capability) bool {
...
@@ -81,15 +86,19 @@ func (m *Model) Has(caps ...Capability) bool {
}
}
if
_
,
ok
:=
ggml
.
KV
()[
fmt
.
Sprintf
(
"%s.pooling_type"
,
ggml
.
KV
()
.
Architecture
())];
ok
{
if
_
,
ok
:=
ggml
.
KV
()[
fmt
.
Sprintf
(
"%s.pooling_type"
,
ggml
.
KV
()
.
Architecture
())];
ok
{
return
false
errs
=
append
(
errs
,
errCapabilityCompletion
)
}
}
default
:
default
:
slog
.
Error
(
"unknown capability"
,
"capability"
,
cap
)
slog
.
Error
(
"unknown capability"
,
"capability"
,
cap
)
return
f
alse
return
f
mt
.
Errorf
(
"unknown capability: %s"
,
cap
)
}
}
}
}
return
true
if
err
:=
errors
.
Join
(
errs
...
);
err
!=
nil
{
return
fmt
.
Errorf
(
"missing capabilities: %w"
,
errors
.
Join
(
errs
...
))
}
return
nil
}
}
func
(
m
*
Model
)
String
()
string
{
func
(
m
*
Model
)
String
()
string
{
...
...
server/prompt.go
View file @
9bbddc37
package
server
package
server
import
(
import
(
"fmt"
"bytes"
"context"
"log/slog"
"log/slog"
"strings"
"slices"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/template"
)
)
// isResponseNode checks if the node contains .Response
type
tokenizeFunc
func
(
context
.
Context
,
string
)
([]
int
,
error
)
func
isResponseNode
(
node
*
parse
.
ActionNode
)
bool
{
for
_
,
cmd
:=
range
node
.
Pipe
.
Cmds
{
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
for
_
,
arg
:=
range
cmd
.
Args
{
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
if
fieldNode
,
ok
:=
arg
.
(
*
parse
.
FieldNode
);
ok
&&
len
(
fieldNode
.
Ident
)
>
0
{
// latest message and 2) system messages
if
fieldNode
.
Ident
[
0
]
==
"Response"
{
func
chatPrompt
(
ctx
context
.
Context
,
m
*
Model
,
tokenize
tokenizeFunc
,
opts
*
api
.
Options
,
msgs
[]
api
.
Message
)
(
prompt
string
,
images
[]
llm
.
ImageData
,
_
error
)
{
return
true
// pull out any system messages which should always be included in the prompt
}
var
system
[]
api
.
Message
}
msgs
=
slices
.
DeleteFunc
(
msgs
,
func
(
m
api
.
Message
)
bool
{
}
if
m
.
Role
==
"system"
{
}
system
=
append
(
system
,
m
)
return
false
return
true
}
// formatTemplateForResponse formats the template AST to:
// 1. remove all nodes after the first .Response (if generate=true)
// 2. add a .Response node to the end if it doesn't exist
// TODO(jmorganca): this should recursively cut the template before the first .Response
func
formatTemplateForResponse
(
tmpl
*
template
.
Template
,
generate
bool
)
{
var
found
bool
for
i
,
node
:=
range
tmpl
.
Tree
.
Root
.
Nodes
{
if
actionNode
,
ok
:=
node
.
(
*
parse
.
ActionNode
);
ok
{
if
isResponseNode
(
actionNode
)
{
found
=
true
if
generate
{
tmpl
.
Tree
.
Root
.
Nodes
=
tmpl
.
Tree
.
Root
.
Nodes
[
:
i
+
1
]
break
}
}
}
}
}
if
!
found
{
// add the response node if it doesn't exist
responseFieldNode
:=
&
parse
.
FieldNode
{
NodeType
:
parse
.
NodeField
,
Ident
:
[]
string
{
"Response"
}}
responsePipeNode
:=
&
parse
.
PipeNode
{
NodeType
:
parse
.
NodePipe
,
Cmds
:
[]
*
parse
.
CommandNode
{{
NodeType
:
parse
.
NodeCommand
,
Args
:
[]
parse
.
Node
{
responseFieldNode
}}}}
responseActionNode
:=
&
parse
.
ActionNode
{
NodeType
:
parse
.
NodeAction
,
Pipe
:
responsePipeNode
}
tmpl
.
Tree
.
Root
.
Nodes
=
append
(
tmpl
.
Tree
.
Root
.
Nodes
,
responseActionNode
)
}
}
// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func
Prompt
(
tmpl
*
template
.
Template
,
system
,
prompt
,
response
string
,
generate
bool
)
(
string
,
error
)
{
formatTemplateForResponse
(
tmpl
,
generate
)
vars
:=
map
[
string
]
any
{
"System"
:
system
,
"Prompt"
:
prompt
,
"Response"
:
response
,
}
var
sb
strings
.
Builder
if
err
:=
tmpl
.
Execute
(
&
sb
,
vars
);
err
!=
nil
{
return
""
,
err
}
return
sb
.
String
(),
nil
return
false
}
})
func
countTokens
(
tmpl
*
template
.
Template
,
system
string
,
prompt
string
,
response
string
,
encode
func
(
string
)
([]
int
,
error
))
(
int
,
error
)
{
if
len
(
system
)
==
0
&&
m
.
System
!=
""
{
rendered
,
err
:=
Prompt
(
tmpl
,
system
,
prompt
,
response
,
false
)
// add model system prompt since it wasn't provided
if
err
!=
nil
{
system
=
append
(
system
,
api
.
Message
{
Role
:
"system"
,
Content
:
m
.
System
})
return
0
,
err
}
}
tokens
,
err
:=
encode
(
rendered
)
// always include the last message
if
err
!=
nil
{
n
:=
len
(
msgs
)
-
1
slog
.
Error
(
"failed to encode prompt"
,
"err"
,
err
)
// in reverse, find all messages that fit into context window
return
0
,
err
for
i
:=
n
-
1
;
i
>=
0
;
i
--
{
}
var
b
bytes
.
Buffer
if
err
:=
m
.
Template
.
Execute
(
&
b
,
template
.
Values
{
Messages
:
append
(
system
,
msgs
[
i
:
]
...
)});
err
!=
nil
{
return
len
(
tokens
),
err
return
""
,
nil
,
err
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func
ChatPrompt
(
tmpl
*
template
.
Template
,
messages
[]
api
.
Message
,
window
int
,
encode
func
(
string
)
([]
int
,
error
))
(
string
,
error
)
{
type
prompt
struct
{
System
string
Prompt
string
Response
string
images
[]
int
tokens
int
}
var
p
prompt
// iterate through messages to build up {system,user,response} prompts
var
imgId
int
var
prompts
[]
prompt
for
_
,
msg
:=
range
messages
{
switch
strings
.
ToLower
(
msg
.
Role
)
{
case
"system"
:
if
p
.
System
!=
""
||
p
.
Prompt
!=
""
||
p
.
Response
!=
""
{
prompts
=
append
(
prompts
,
p
)
p
=
prompt
{}
}
p
.
System
=
msg
.
Content
case
"user"
:
if
p
.
Prompt
!=
""
||
p
.
Response
!=
""
{
prompts
=
append
(
prompts
,
p
)
p
=
prompt
{}
}
var
sb
strings
.
Builder
for
range
msg
.
Images
{
fmt
.
Fprintf
(
&
sb
,
"[img-%d] "
,
imgId
)
p
.
images
=
append
(
p
.
images
,
imgId
)
imgId
+=
1
}
sb
.
WriteString
(
msg
.
Content
)
p
.
Prompt
=
sb
.
String
()
case
"assistant"
:
if
p
.
Response
!=
""
{
prompts
=
append
(
prompts
,
p
)
p
=
prompt
{}
}
p
.
Response
=
msg
.
Content
default
:
return
""
,
fmt
.
Errorf
(
"invalid role: %s, role must be one of [system, user, assistant]"
,
msg
.
Role
)
}
}
}
// add final prompt
if
p
.
System
!=
""
||
p
.
Prompt
!=
""
||
p
.
Response
!=
""
{
prompts
=
append
(
prompts
,
p
)
}
// calculate token lengths for each prompt, estimating 768 tokens per images
s
,
err
:=
tokenize
(
ctx
,
b
.
String
())
for
i
,
p
:=
range
prompts
{
tokens
,
err
:=
countTokens
(
tmpl
,
p
.
System
,
p
.
Prompt
,
p
.
Response
,
encode
)
if
err
!=
nil
{
if
err
!=
nil
{
return
""
,
err
return
""
,
nil
,
err
}
}
prompts
[
i
]
.
tokens
=
tokens
+
len
(
prompts
[
i
]
.
images
)
*
768
c
:=
len
(
s
)
}
if
m
.
ProjectorPaths
!=
nil
{
for
_
,
m
:=
range
msgs
[
i
:
]
{
// truncate images and prompts starting from the beginning of the list
// images are represented as 768 sized embeddings
// until either one prompt remains or the total tokens fits the context window
// TODO: get embedding length from project metadata
// TODO (jmorganca): this doesn't account for the context window room required for the response
c
+=
768
*
len
(
m
.
Images
)
for
{
}
var
required
int
for
_
,
p
:=
range
prompts
{
required
+=
p
.
tokens
}
}
required
+=
1
// for bos token
if
c
>
opts
.
NumCtx
{
slog
.
Debug
(
"truncating input messages which exceed context length"
,
"truncated"
,
len
(
msgs
[
i
:
]))
if
required
<=
window
{
slog
.
Debug
(
"prompt now fits in context window"
,
"required"
,
required
,
"window"
,
window
)
break
break
}
else
{
n
=
i
}
}
}
prompt
:=
&
prompts
[
0
]
// truncate any messages that do not fit into the context window
var
b
bytes
.
Buffer
if
len
(
prompt
.
images
)
>
1
{
if
err
:=
m
.
Template
.
Execute
(
&
b
,
template
.
Values
{
Messages
:
append
(
system
,
msgs
[
n
:
]
...
)});
err
!=
nil
{
img
:=
prompt
.
images
[
0
]
return
""
,
nil
,
err
slog
.
Debug
(
"prompt longer than context window, removing image"
,
"id"
,
img
,
"required"
,
required
,
"window"
,
window
)
prompt
.
images
=
prompt
.
images
[
1
:
]
prompt
.
Prompt
=
strings
.
Replace
(
prompt
.
Prompt
,
fmt
.
Sprintf
(
" [img-%d]"
,
img
),
""
,
1
)
prompt
.
tokens
-=
768
continue
}
if
len
(
prompts
)
>
1
{
slog
.
Debug
(
"required tokens longer than context window, removing first prompt"
,
"prompt"
,
prompts
[
0
]
.
tokens
,
"required"
,
required
,
"window"
,
window
)
system
:=
prompt
.
System
prompts
=
prompts
[
1
:
]
if
system
!=
""
&&
prompts
[
0
]
.
System
==
""
{
prompts
[
0
]
.
System
=
system
tokens
,
err
:=
countTokens
(
tmpl
,
prompts
[
0
]
.
System
,
prompts
[
0
]
.
Prompt
,
prompts
[
0
]
.
Response
,
encode
)
if
err
!=
nil
{
return
""
,
err
}
prompts
[
0
]
.
tokens
=
tokens
+
len
(
prompts
[
0
]
.
images
)
*
768
}
continue
}
// stop truncating if there's only one prompt left
break
}
}
var
sb
strings
.
Builder
for
_
,
m
:=
range
msgs
[
n
:
]
{
for
i
,
p
:=
range
prompt
s
{
for
_
,
i
:=
range
m
.
Image
s
{
// last prompt should leave the response unrendered (for completion)
images
=
append
(
images
,
llm
.
ImageData
{
rendered
,
err
:=
Prompt
(
tmpl
,
p
.
System
,
p
.
Prompt
,
p
.
Response
,
i
==
len
(
prompts
)
-
1
)
ID
:
len
(
images
),
if
err
!=
nil
{
Data
:
i
,
return
""
,
err
})
}
}
sb
.
WriteString
(
rendered
)
}
}
return
s
b
.
String
(),
nil
return
b
.
String
(),
images
,
nil
}
}
server/prompt_test.go
View file @
9bbddc37
package
server
package
server
import
(
import
(
"bytes"
"context"
"strings"
"strings"
"testing"
"testing"
...
@@ -8,208 +10,195 @@ import (
...
@@ -8,208 +10,195 @@ import (
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/template"
)
)
func
TestPrompt
(
t
*
testing
.
T
)
{
func
tokenize
(
_
context
.
Context
,
s
string
)
(
tokens
[]
int
,
err
error
)
{
tests
:=
[]
struct
{
for
range
strings
.
Fields
(
s
)
{
name
string
tokens
=
append
(
tokens
,
len
(
tokens
))
template
string
system
string
prompt
string
response
string
generate
bool
want
string
}{
{
name
:
"simple prompt"
,
template
:
"[INST] {{ .System }} {{ .Prompt }} [/INST]"
,
system
:
"You are a Wizard."
,
prompt
:
"What are the potion ingredients?"
,
want
:
"[INST] You are a Wizard. What are the potion ingredients? [/INST]"
,
},
{
name
:
"implicit response"
,
template
:
"[INST] {{ .System }} {{ .Prompt }} [/INST]"
,
system
:
"You are a Wizard."
,
prompt
:
"What are the potion ingredients?"
,
response
:
"I don't know."
,
want
:
"[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know."
,
},
{
name
:
"response"
,
template
:
"[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}"
,
system
:
"You are a Wizard."
,
prompt
:
"What are the potion ingredients?"
,
response
:
"I don't know."
,
want
:
"[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know."
,
},
{
name
:
"cut"
,
template
:
"<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>"
,
system
:
"You are a Wizard."
,
prompt
:
"What are the potion ingredients?"
,
response
:
"I don't know."
,
generate
:
true
,
want
:
"<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know."
,
},
{
name
:
"nocut"
,
template
:
"<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>"
,
system
:
"You are a Wizard."
,
prompt
:
"What are the potion ingredients?"
,
response
:
"I don't know."
,
want
:
"<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>"
,
},
}
}
for
_
,
tc
:=
range
tests
{
return
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
tmpl
,
err
:=
template
.
Parse
(
tc
.
template
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
got
,
err
:=
Prompt
(
tmpl
,
tc
.
system
,
tc
.
prompt
,
tc
.
response
,
tc
.
generate
)
if
err
!=
nil
{
t
.
Errorf
(
"error = %v"
,
err
)
}
if
got
!=
tc
.
want
{
t
.
Errorf
(
"got = %v, want %v"
,
got
,
tc
.
want
)
}
})
}
}
}
func
TestChatPrompt
(
t
*
testing
.
T
)
{
func
TestChatPrompt
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
type
expect
struct
{
name
string
prompt
string
template
string
images
[][]
byte
messages
[]
api
.
Message
}
window
int
want
string
cases
:=
[]
struct
{
name
string
limit
int
msgs
[]
api
.
Message
expect
}{
}{
{
{
name
:
"simple prompt"
,
name
:
"messages"
,
template
:
"[INST] {{ .Prompt }} [/INST]"
,
limit
:
64
,
messages
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"Hello"
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
{
Role
:
"assistant"
,
Content
:
"I-I'm a what?"
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
},
expect
:
expect
{
prompt
:
"You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. "
,
},
},
window
:
1024
,
want
:
"[INST] Hello [/INST]"
,
},
{
name
:
"with system message"
,
template
:
"[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]"
,
messages
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a Wizard."
},
{
Role
:
"user"
,
Content
:
"Hello"
},
},
window
:
1024
,
want
:
"[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]"
,
},
},
{
{
name
:
"with response"
,
name
:
"truncate messages"
,
template
:
"[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}"
,
limit
:
1
,
messages
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a Wizard."
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
{
Role
:
"user"
,
Content
:
"Hello"
},
{
Role
:
"assistant"
,
Content
:
"I-I'm a what?"
},
{
Role
:
"assistant"
,
Content
:
"I am?"
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
},
},
window
:
1024
,
expect
:
expect
{
want
:
"[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?"
,
prompt
:
"A test. And a thumping good one at that, I'd wager. "
,
},
},
},
{
{
name
:
"with implicit response"
,
name
:
"truncate messages with image"
,
template
:
"[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]"
,
limit
:
64
,
messages
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a Wizard."
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
{
Role
:
"user"
,
Content
:
"Hello"
},
{
Role
:
"assistant"
,
Content
:
"I-I'm a what?"
},
{
Role
:
"assistant"
,
Content
:
"I am?"
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
},
},
window
:
1024
,
expect
:
expect
{
want
:
"[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?"
,
prompt
:
"[img-0] A test. And a thumping good one at that, I'd wager. "
,
images
:
[][]
byte
{
[]
byte
(
"something"
),
},
},
},
},
{
{
name
:
"with conversation"
,
name
:
"truncate messages with images"
,
template
:
"[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} "
,
limit
:
64
,
messages
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a Wizard."
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
{
Role
:
"user"
,
Content
:
"What are the potion ingredients?"
},
{
Role
:
"assistant"
,
Content
:
"I-I'm a what?"
},
{
Role
:
"assistant"
,
Content
:
"sugar"
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"somethingelse"
)}},
{
Role
:
"user"
,
Content
:
"Anything else?"
},
},
},
expect
:
expect
{
window
:
1024
,
prompt
:
"[img-0] A test. And a thumping good one at that, I'd wager. "
,
want
:
"[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] "
,
images
:
[][]
byte
{
[]
byte
(
"somethingelse"
),
},
},
},
},
{
{
name
:
"with truncation"
,
name
:
"messages with images"
,
template
:
"{{ .System }} {{ .Prompt }} {{ .Response }} "
,
limit
:
2048
,
messages
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a Wizard."
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
{
Role
:
"user"
,
Content
:
"Hello"
},
{
Role
:
"assistant"
,
Content
:
"I-I'm a what?"
},
{
Role
:
"assistant"
,
Content
:
"I am?"
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"somethingelse"
)}},
{
Role
:
"user"
,
Content
:
"Why is the sky blue?"
},
},
{
Role
:
"assistant"
,
Content
:
"The sky is blue from rayleigh scattering"
},
expect
:
expect
{
},
prompt
:
"[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. "
,
window
:
10
,
images
:
[][]
byte
{
want
:
"You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering"
,
[]
byte
(
"something"
),
[]
byte
(
"somethingelse"
),
},
},
},
},
{
{
name
:
"images"
,
name
:
"message with image tag"
,
template
:
"{{ .System }} {{ .Prompt }}"
,
limit
:
2048
,
messages
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a Wizard."
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry! [img]"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
{
Role
:
"user"
,
Content
:
"Hello"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"base64"
)}},
{
Role
:
"assistant"
,
Content
:
"I-I'm a what?"
},
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"somethingelse"
)}},
window
:
1024
,
},
want
:
"You are a Wizard. [img-0] Hello"
,
expect
:
expect
{
prompt
:
"You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. "
,
images
:
[][]
byte
{
[]
byte
(
"something"
),
[]
byte
(
"somethingelse"
),
},
},
},
},
{
{
name
:
"images truncated"
,
name
:
"messages with interleaved images"
,
template
:
"{{ .System }} {{ .Prompt }}"
,
limit
:
2048
,
messages
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a Wizard."
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
{
Role
:
"user"
,
Content
:
"Hello"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"img1"
),
[]
byte
(
"img2"
)}},
{
Role
:
"user"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
},
{
Role
:
"user"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"somethingelse"
)}},
window
:
1024
,
{
Role
:
"assistant"
,
Content
:
"I-I'm a what?"
},
want
:
"You are a Wizard. [img-0] [img-1] Hello"
,
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
},
expect
:
expect
{
prompt
:
"You're a test, Harry!
\n\n
[img-0]
\n\n
[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. "
,
images
:
[][]
byte
{
[]
byte
(
"something"
),
[]
byte
(
"somethingelse"
),
},
},
},
},
{
{
name
:
"empty list"
,
name
:
"truncate message with interleaved images"
,
template
:
"{{ .System }} {{ .Prompt }}"
,
limit
:
1024
,
messages
:
[]
api
.
Message
{},
msgs
:
[]
api
.
Message
{
window
:
1024
,
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
want
:
""
,
{
Role
:
"user"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
{
Role
:
"user"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"somethingelse"
)}},
{
Role
:
"assistant"
,
Content
:
"I-I'm a what?"
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
},
expect
:
expect
{
prompt
:
"[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. "
,
images
:
[][]
byte
{
[]
byte
(
"somethingelse"
),
},
},
},
},
{
{
name
:
"empty prompt"
,
name
:
"message with system prompt"
,
template
:
"[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} "
,
limit
:
2048
,
messages
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
""
},
{
Role
:
"system"
,
Content
:
"You are the Test Who Lived."
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
{
Role
:
"assistant"
,
Content
:
"I-I'm a what?"
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
},
expect
:
expect
{
prompt
:
"You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. "
,
},
},
window
:
1024
,
want
:
""
,
},
},
}
}
encode
:=
func
(
s
string
)
([]
int
,
error
)
{
tmpl
,
err
:=
template
.
Parse
(
`
words
:=
strings
.
Fields
(
s
)
{{- if .System }}{{ .System }} {{ end }}
return
make
([]
int
,
len
(
words
)),
nil
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
}
for
_
,
tc
:=
range
tests
{
for
_
,
tt
:=
range
cases
{
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
tmpl
,
err
:=
template
.
Parse
(
tc
.
template
)
model
:=
Model
{
Template
:
tmpl
,
ProjectorPaths
:
[]
string
{
"vision"
}}
opts
:=
api
.
Options
{
Runner
:
api
.
Runner
{
NumCtx
:
tt
.
limit
}}
prompt
,
images
,
err
:=
chatPrompt
(
context
.
TODO
(),
&
model
,
tokenize
,
&
opts
,
tt
.
msgs
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
}
}
got
,
err
:=
ChatPrompt
(
tmpl
,
tc
.
messages
,
tc
.
window
,
encode
)
if
tt
.
prompt
!=
prompt
{
if
err
!=
nil
{
t
.
Errorf
(
"expected %q, got %q"
,
tt
.
prompt
,
prompt
)
t
.
Errorf
(
"error = %v"
,
err
)
}
}
if
got
!=
tc
.
want
{
if
len
(
images
)
!=
len
(
tt
.
images
)
{
t
.
Errorf
(
"got: %q, want: %q"
,
got
,
tc
.
want
)
t
.
Fatalf
(
"expected %d images, got %d"
,
len
(
tt
.
images
),
len
(
images
))
}
for
i
:=
range
images
{
if
images
[
i
]
.
ID
!=
i
{
t
.
Errorf
(
"expected ID %d, got %d"
,
i
,
images
[
i
]
.
ID
)
}
if
!
bytes
.
Equal
(
images
[
i
]
.
Data
,
tt
.
images
[
i
])
{
t
.
Errorf
(
"expected %q, got %q"
,
tt
.
images
[
i
],
images
[
i
])
}
}
}
})
})
}
}
...
...
server/routes.go
View file @
9bbddc37
package
server
package
server
import
(
import
(
"bytes"
"cmp"
"cmp"
"context"
"context"
"encoding/json"
"encoding/json"
"errors"
"errors"
"fmt"
"fmt"
"io"
"io"
"io/fs"
"log/slog"
"log/slog"
"net"
"net"
"net/http"
"net/http"
...
@@ -54,6 +54,8 @@ func init() {
...
@@ -54,6 +54,8 @@ func init() {
gin
.
SetMode
(
mode
)
gin
.
SetMode
(
mode
)
}
}
var
errRequired
=
errors
.
New
(
"is required"
)
func
modelOptions
(
model
*
Model
,
requestOpts
map
[
string
]
interface
{})
(
api
.
Options
,
error
)
{
func
modelOptions
(
model
*
Model
,
requestOpts
map
[
string
]
interface
{})
(
api
.
Options
,
error
)
{
opts
:=
api
.
DefaultOptions
()
opts
:=
api
.
DefaultOptions
()
if
err
:=
opts
.
FromMap
(
model
.
Options
);
err
!=
nil
{
if
err
:=
opts
.
FromMap
(
model
.
Options
);
err
!=
nil
{
...
@@ -67,163 +69,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
...
@@ -67,163 +69,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return
opts
,
nil
return
opts
,
nil
}
}
func
isSupportedImageType
(
image
[]
byte
)
bool
{
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
contentType
:=
http
.
DetectContentType
(
image
)
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
allowedTypes
:=
[]
string
{
"image/jpeg"
,
"image/jpg"
,
"image/png"
}
func
(
s
*
Server
)
scheduleRunner
(
ctx
context
.
Context
,
name
string
,
caps
[]
Capability
,
requestOpts
map
[
string
]
any
,
keepAlive
*
api
.
Duration
)
(
llm
.
LlamaServer
,
*
Model
,
*
api
.
Options
,
error
)
{
return
slices
.
Contains
(
allowedTypes
,
contentType
)
if
name
==
""
{
return
nil
,
nil
,
nil
,
fmt
.
Errorf
(
"model %w"
,
errRequired
)
}
model
,
err
:=
GetModel
(
name
)
if
err
!=
nil
{
return
nil
,
nil
,
nil
,
err
}
if
err
:=
model
.
CheckCapabilities
(
caps
...
);
err
!=
nil
{
return
nil
,
nil
,
nil
,
fmt
.
Errorf
(
"%s %w"
,
name
,
err
)
}
opts
,
err
:=
modelOptions
(
model
,
requestOpts
)
if
err
!=
nil
{
return
nil
,
nil
,
nil
,
err
}
runnerCh
,
errCh
:=
s
.
sched
.
GetRunner
(
ctx
,
model
,
opts
,
keepAlive
)
var
runner
*
runnerRef
select
{
case
runner
=
<-
runnerCh
:
case
err
=
<-
errCh
:
return
nil
,
nil
,
nil
,
err
}
return
runner
.
llama
,
model
,
&
opts
,
nil
}
}
func
(
s
*
Server
)
GenerateHandler
(
c
*
gin
.
Context
)
{
func
(
s
*
Server
)
GenerateHandler
(
c
*
gin
.
Context
)
{
checkpointStart
:=
time
.
Now
()
var
req
api
.
GenerateRequest
var
req
api
.
GenerateRequest
err
:=
c
.
ShouldBindJSON
(
&
req
)
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
errors
.
Is
(
err
,
io
.
EOF
)
{
switch
{
case
errors
.
Is
(
err
,
io
.
EOF
)
:
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"missing request body"
})
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"missing request body"
})
return
return
case
err
!=
nil
:
}
else
if
err
!=
nil
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
return
}
}
// validate the request
if
req
.
Format
!=
""
&&
req
.
Format
!=
"json"
{
switch
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"format must be empty or
\"
json
\"
"
})
case
req
.
Model
==
""
:
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"model is required"
})
return
case
len
(
req
.
Format
)
>
0
&&
req
.
Format
!=
"json"
:
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"format must be json"
})
return
return
case
req
.
Raw
&&
(
req
.
Template
!=
""
||
req
.
System
!=
""
||
len
(
req
.
Context
)
>
0
)
:
}
else
if
req
.
Raw
&&
(
req
.
Template
!=
""
||
req
.
System
!=
""
||
len
(
req
.
Context
)
>
0
)
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"raw mode does not support template, system, or context"
})
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"raw mode does not support template, system, or context"
})
return
return
}
}
for
_
,
img
:=
range
req
.
Images
{
caps
:=
[]
Capability
{
CapabilityCompletion
}
if
!
isSupportedImageType
(
img
)
{
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
caps
,
req
.
Options
,
req
.
KeepAlive
)
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"unsupported image format"
})
if
errors
.
Is
(
err
,
errCapabilityCompletion
)
{
return
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support generate"
,
req
.
Model
)})
}
}
model
,
err
:=
GetModel
(
req
.
Model
)
if
err
!=
nil
{
var
pErr
*
fs
.
PathError
if
errors
.
As
(
err
,
&
pErr
)
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found, try pulling it first"
,
req
.
Model
)})
return
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
if
!
model
.
Has
(
CapabilityCompletion
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%s does not support generate"
,
req
.
Model
)})
return
}
opts
,
err
:=
modelOptions
(
model
,
req
.
Options
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
return
}
}
else
if
err
!=
nil
{
handleScheduleError
(
c
,
req
.
Model
,
err
)
rCh
,
eCh
:=
s
.
sched
.
GetRunner
(
c
.
Request
.
Context
(),
model
,
opts
,
req
.
KeepAlive
)
var
runner
*
runnerRef
select
{
case
runner
=
<-
rCh
:
case
err
=
<-
eCh
:
handleErrorResponse
(
c
,
err
)
return
return
}
}
// an empty request loads the model
if
req
.
Prompt
==
""
{
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if
req
.
Prompt
==
""
&&
req
.
Template
==
""
&&
req
.
System
==
""
{
c
.
JSON
(
http
.
StatusOK
,
api
.
GenerateResponse
{
c
.
JSON
(
http
.
StatusOK
,
api
.
GenerateResponse
{
CreatedAt
:
time
.
Now
()
.
UTC
(),
Model
:
req
.
Model
,
Model
:
req
.
Model
,
CreatedAt
:
time
.
Now
()
.
UTC
(),
Done
:
true
,
Done
:
true
,
DoneReason
:
"load"
,
DoneReason
:
"load"
,
})
})
return
return
}
}
tmpl
,
err
:=
template
.
Parse
(
req
.
Template
)
images
:=
make
([]
llm
.
ImageData
,
len
(
req
.
Images
))
if
err
!=
nil
{
for
i
:=
range
req
.
Images
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
images
[
i
]
=
llm
.
ImageData
{
ID
:
i
,
Data
:
req
.
Images
[
i
]}
return
}
}
checkpointLoaded
:=
time
.
Now
()
prompt
:=
req
.
Prompt
if
!
req
.
Raw
{
var
prompt
string
var
msgs
[]
api
.
Message
switch
{
if
req
.
System
!=
""
{
case
req
.
Raw
:
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"system"
,
Content
:
req
.
System
})
prompt
=
req
.
Prompt
}
else
if
m
.
System
!=
""
{
case
req
.
Prompt
!=
""
:
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"system"
,
Content
:
m
.
System
})
if
req
.
Template
==
""
{
tmpl
=
model
.
Template
}
}
if
req
.
System
==
""
{
for
_
,
i
:=
range
images
{
req
.
System
=
model
.
System
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
fmt
.
Sprintf
(
"[img-%d]"
,
i
.
ID
)})
}
}
slog
.
Debug
(
"generate handler"
,
"prompt"
,
req
.
Prompt
)
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
req
.
Prompt
})
slog
.
Debug
(
"generate handler"
,
"template"
,
req
.
Template
)
slog
.
Debug
(
"generate handler"
,
"system"
,
req
.
System
)
var
sb
strings
.
Builder
for
i
:=
range
req
.
Images
{
fmt
.
Fprintf
(
&
sb
,
"[img-%d] "
,
i
)
}
sb
.
WriteString
(
req
.
Prompt
)
tmpl
:=
m
.
Template
if
req
.
Template
!=
""
{
p
,
err
:=
Prompt
(
tmpl
,
req
.
System
,
sb
.
String
(),
""
,
true
)
tmpl
,
err
=
template
.
Parse
(
req
.
Template
)
if
err
!=
nil
{
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
return
}
}
}
sb
.
Reset
()
var
b
bytes
.
Buffer
if
req
.
Context
!=
nil
{
if
req
.
Context
!=
nil
{
prev
,
err
:=
r
unner
.
llama
.
Detokenize
(
c
.
Request
.
Context
(),
req
.
Context
)
s
,
err
:=
r
.
Detokenize
(
c
.
Request
.
Context
(),
req
.
Context
)
if
err
!=
nil
{
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
return
}
}
s
b
.
WriteString
(
prev
)
b
.
WriteString
(
s
)
}
}
sb
.
WriteString
(
p
)
if
err
:=
tmpl
.
Execute
(
&
b
,
template
.
Values
{
Messages
:
msgs
});
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
prompt
=
s
b
.
String
()
prompt
=
b
.
String
()
}
}
slog
.
Debug
(
"generate
handler
"
,
"prompt"
,
prompt
)
slog
.
Debug
(
"generate
request
"
,
"prompt"
,
prompt
,
"images"
,
images
)
ch
:=
make
(
chan
any
)
ch
:=
make
(
chan
any
)
var
generated
strings
.
Builder
go
func
()
{
go
func
()
{
defer
close
(
ch
)
defer
close
(
ch
)
if
err
:=
r
.
Completion
(
c
.
Request
.
Context
(),
llm
.
CompletionRequest
{
fn
:=
func
(
r
llm
.
CompletionResponse
)
{
Prompt
:
prompt
,
// Build up the full response
Images
:
images
,
if
_
,
err
:=
generated
.
WriteString
(
r
.
Content
);
err
!=
nil
{
Format
:
req
.
Format
,
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
Options
:
opts
,
return
},
func
(
r
llm
.
CompletionResponse
)
{
}
ch
<-
api
.
GenerateResponse
{
resp
:=
api
.
GenerateResponse
{
Model
:
req
.
Model
,
Model
:
req
.
Model
,
CreatedAt
:
time
.
Now
()
.
UTC
(),
CreatedAt
:
time
.
Now
()
.
UTC
(),
Done
:
r
.
Done
,
Response
:
r
.
Content
,
Response
:
r
.
Content
,
Done
:
r
.
Done
,
DoneReason
:
r
.
DoneReason
,
DoneReason
:
r
.
DoneReason
,
Metrics
:
api
.
Metrics
{
Metrics
:
api
.
Metrics
{
PromptEvalCount
:
r
.
PromptEvalCount
,
PromptEvalCount
:
r
.
PromptEvalCount
,
...
@@ -232,77 +211,35 @@ func (s *Server) GenerateHandler(c *gin.Context) {
...
@@ -232,77 +211,35 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalDuration
:
r
.
EvalDuration
,
EvalDuration
:
r
.
EvalDuration
,
},
},
}
}
});
err
!=
nil
{
if
r
.
Done
{
resp
.
TotalDuration
=
time
.
Since
(
checkpointStart
)
resp
.
LoadDuration
=
checkpointLoaded
.
Sub
(
checkpointStart
)
if
!
req
.
Raw
{
p
,
err
:=
Prompt
(
tmpl
,
req
.
System
,
req
.
Prompt
,
generated
.
String
(),
false
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens
,
err
:=
runner
.
llama
.
Tokenize
(
c
.
Request
.
Context
(),
p
)
if
err
!=
nil
{
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
return
}
resp
.
Context
=
append
(
req
.
Context
,
tokens
...
)
}
}
ch
<-
resp
}
var
images
[]
llm
.
ImageData
for
i
:=
range
req
.
Images
{
images
=
append
(
images
,
llm
.
ImageData
{
ID
:
i
,
Data
:
req
.
Images
[
i
],
})
}
// Start prediction
req
:=
llm
.
CompletionRequest
{
Prompt
:
prompt
,
Format
:
req
.
Format
,
Images
:
images
,
Options
:
opts
,
}
if
err
:=
runner
.
llama
.
Completion
(
c
.
Request
.
Context
(),
req
,
fn
);
err
!=
nil
{
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
}
}
}()
}()
if
req
.
Stream
!=
nil
&&
!*
req
.
Stream
{
if
req
.
Stream
!=
nil
&&
!*
req
.
Stream
{
// Accumulate responses into the final response
var
r
api
.
GenerateResponse
var
final
api
.
GenerateResponse
var
sb
strings
.
Builder
var
sb
strings
.
Builder
for
r
esp
:=
range
ch
{
for
r
r
:=
range
ch
{
switch
r
:=
r
esp
.
(
type
)
{
switch
t
:=
r
r
.
(
type
)
{
case
api
.
GenerateResponse
:
case
api
.
GenerateResponse
:
sb
.
WriteString
(
r
.
Response
)
sb
.
WriteString
(
t
.
Response
)
final
=
r
r
=
t
case
gin
.
H
:
case
gin
.
H
:
if
errorMsg
,
ok
:=
r
[
"error"
]
.
(
string
);
ok
{
msg
,
ok
:=
t
[
"error"
]
.
(
string
)
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
errorMsg
})
if
!
ok
{
return
msg
=
"unexpected error format in response"
}
else
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
"unexpected error format in response"
})
return
}
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
msg
})
return
default
:
default
:
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
"unexpected
error
"
})
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
"unexpected
response
"
})
return
return
}
}
}
}
final
.
Response
=
sb
.
String
()
r
.
Response
=
sb
.
String
()
c
.
JSON
(
http
.
StatusOK
,
final
)
c
.
JSON
(
http
.
StatusOK
,
r
)
return
return
}
}
...
@@ -311,44 +248,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
...
@@ -311,44 +248,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
func
(
s
*
Server
)
EmbeddingsHandler
(
c
*
gin
.
Context
)
{
func
(
s
*
Server
)
EmbeddingsHandler
(
c
*
gin
.
Context
)
{
var
req
api
.
EmbeddingRequest
var
req
api
.
EmbeddingRequest
err
:=
c
.
ShouldBindJSON
(
&
req
)
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
errors
.
Is
(
err
,
io
.
EOF
)
{
switch
{
case
errors
.
Is
(
err
,
io
.
EOF
)
:
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"missing request body"
})
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"missing request body"
})
return
return
case
err
!=
nil
:
}
else
if
err
!=
nil
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
return
}
}
if
req
.
Model
==
""
{
r
,
_
,
_
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
[]
Capability
{},
req
.
Options
,
req
.
KeepAlive
)
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"model is required"
})
return
}
model
,
err
:=
GetModel
(
req
.
Model
)
if
err
!=
nil
{
if
err
!=
nil
{
var
pErr
*
fs
.
PathError
handleScheduleError
(
c
,
req
.
Model
,
err
)
if
errors
.
As
(
err
,
&
pErr
)
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found, try pulling it first"
,
req
.
Model
)})
return
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
opts
,
err
:=
modelOptions
(
model
,
req
.
Options
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
rCh
,
eCh
:=
s
.
sched
.
GetRunner
(
c
.
Request
.
Context
(),
model
,
opts
,
req
.
KeepAlive
)
var
runner
*
runnerRef
select
{
case
runner
=
<-
rCh
:
case
err
=
<-
eCh
:
handleErrorResponse
(
c
,
err
)
return
return
}
}
...
@@ -358,17 +268,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
...
@@ -358,17 +268,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
return
}
}
embedding
,
err
:=
r
unner
.
llama
.
Embedding
(
c
.
Request
.
Context
(),
req
.
Prompt
)
embedding
,
err
:=
r
.
Embedding
(
c
.
Request
.
Context
(),
req
.
Prompt
)
if
err
!=
nil
{
if
err
!=
nil
{
slog
.
Info
(
fmt
.
Sprintf
(
"embedding generation failed: %v"
,
err
))
slog
.
Info
(
fmt
.
Sprintf
(
"embedding generation failed: %v"
,
err
))
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
"failed to generate embedding"
})
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
"failed to generate embedding"
})
return
return
}
}
resp
:=
api
.
EmbeddingResponse
{
c
.
JSON
(
http
.
StatusOK
,
api
.
EmbeddingResponse
{
Embedding
:
embedding
})
Embedding
:
embedding
,
}
c
.
JSON
(
http
.
StatusOK
,
resp
)
}
}
func
(
s
*
Server
)
PullModelHandler
(
c
*
gin
.
Context
)
{
func
(
s
*
Server
)
PullModelHandler
(
c
*
gin
.
Context
)
{
...
@@ -649,9 +556,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
...
@@ -649,9 +556,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
}
}
msgs
:=
make
([]
api
.
Message
,
0
)
msgs
:=
make
([]
api
.
Message
,
len
(
m
.
Messages
)
)
for
_
,
msg
:=
range
m
.
Messages
{
for
i
,
msg
:=
range
m
.
Messages
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
msg
.
Role
,
Content
:
msg
.
Content
}
)
msgs
[
i
]
=
api
.
Message
{
Role
:
msg
.
Role
,
Content
:
msg
.
Content
}
}
}
n
:=
model
.
ParseName
(
req
.
Model
)
n
:=
model
.
ParseName
(
req
.
Model
)
...
@@ -1214,132 +1121,55 @@ func (s *Server) ProcessHandler(c *gin.Context) {
...
@@ -1214,132 +1121,55 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c
.
JSON
(
http
.
StatusOK
,
api
.
ProcessResponse
{
Models
:
models
})
c
.
JSON
(
http
.
StatusOK
,
api
.
ProcessResponse
{
Models
:
models
})
}
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func
chatPrompt
(
ctx
context
.
Context
,
runner
*
runnerRef
,
template
*
template
.
Template
,
messages
[]
api
.
Message
,
numCtx
int
)
(
string
,
error
)
{
encode
:=
func
(
s
string
)
([]
int
,
error
)
{
return
runner
.
llama
.
Tokenize
(
ctx
,
s
)
}
prompt
,
err
:=
ChatPrompt
(
template
,
messages
,
numCtx
,
encode
)
if
err
!=
nil
{
return
""
,
err
}
return
prompt
,
nil
}
func
(
s
*
Server
)
ChatHandler
(
c
*
gin
.
Context
)
{
func
(
s
*
Server
)
ChatHandler
(
c
*
gin
.
Context
)
{
checkpointStart
:=
time
.
Now
()
var
req
api
.
ChatRequest
var
req
api
.
ChatRequest
err
:=
c
.
ShouldBindJSON
(
&
req
)
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
errors
.
Is
(
err
,
io
.
EOF
)
{
switch
{
case
errors
.
Is
(
err
,
io
.
EOF
)
:
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"missing request body"
})
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"missing request body"
})
return
return
case
err
!=
nil
:
}
else
if
err
!=
nil
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
return
}
}
// validate the request
caps
:=
[]
Capability
{
CapabilityCompletion
}
switch
{
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
caps
,
req
.
Options
,
req
.
KeepAlive
)
case
req
.
Model
==
""
:
if
errors
.
Is
(
err
,
errCapabilityCompletion
)
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"model is required"
})
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support chat"
,
req
.
Model
)})
return
case
len
(
req
.
Format
)
>
0
&&
req
.
Format
!=
"json"
:
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"format must be json"
})
return
}
model
,
err
:=
GetModel
(
req
.
Model
)
if
err
!=
nil
{
var
pErr
*
fs
.
PathError
if
errors
.
As
(
err
,
&
pErr
)
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found, try pulling it first"
,
req
.
Model
)})
return
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
if
!
model
.
Has
(
CapabilityCompletion
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%s does not support chat"
,
req
.
Model
)})
return
}
opts
,
err
:=
modelOptions
(
model
,
req
.
Options
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
rCh
,
eCh
:=
s
.
sched
.
GetRunner
(
c
.
Request
.
Context
(),
model
,
opts
,
req
.
KeepAlive
)
var
runner
*
runnerRef
select
{
case
runner
=
<-
rCh
:
case
err
=
<-
eCh
:
handleErrorResponse
(
c
,
err
)
return
return
}
}
else
if
err
!=
nil
{
handleScheduleError
(
c
,
req
.
Model
,
err
)
checkpointLoaded
:=
time
.
Now
()
// if the first message is not a system message, then add the model's default system message
if
len
(
req
.
Messages
)
>
0
&&
req
.
Messages
[
0
]
.
Role
!=
"system"
{
req
.
Messages
=
append
([]
api
.
Message
{
{
Role
:
"system"
,
Content
:
model
.
System
,
},
},
req
.
Messages
...
)
}
prompt
,
err
:=
chatPrompt
(
c
.
Request
.
Context
(),
runner
,
model
.
Template
,
req
.
Messages
,
opts
.
NumCtx
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
return
}
}
// an empty request loads the model
if
len
(
req
.
Messages
)
==
0
{
if
len
(
req
.
Messages
)
==
0
||
prompt
==
""
{
c
.
JSON
(
http
.
StatusOK
,
api
.
ChatResponse
{
resp
:=
api
.
ChatResponse
{
CreatedAt
:
time
.
Now
()
.
UTC
(),
Model
:
req
.
Model
,
Model
:
req
.
Model
,
CreatedAt
:
time
.
Now
()
.
UTC
(),
Message
:
api
.
Message
{
Role
:
"assistant"
},
Done
:
true
,
Done
:
true
,
DoneReason
:
"load"
,
DoneReason
:
"load"
,
Message
:
api
.
Message
{
Role
:
"assistant"
},
})
}
c
.
JSON
(
http
.
StatusOK
,
resp
)
return
return
}
}
// only send images that are in the prompt
prompt
,
images
,
err
:=
chatPrompt
(
c
.
Request
.
Context
(),
m
,
r
.
Tokenize
,
opts
,
req
.
Messages
)
var
i
int
if
err
!=
nil
{
var
images
[]
llm
.
ImageData
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
for
_
,
m
:=
range
req
.
Messages
{
return
for
_
,
img
:=
range
m
.
Images
{
if
!
isSupportedImageType
(
img
)
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"unsupported image format"
})
return
}
if
strings
.
Contains
(
prompt
,
fmt
.
Sprintf
(
"[img-%d]"
,
i
))
{
images
=
append
(
images
,
llm
.
ImageData
{
Data
:
img
,
ID
:
i
})
}
i
+=
1
}
}
}
slog
.
Debug
(
"chat
handler"
,
"prompt"
,
prompt
,
"images"
,
len
(
images
))
slog
.
Debug
(
"chat
request"
,
"images"
,
len
(
images
)
,
"prompt"
,
prompt
)
ch
:=
make
(
chan
any
)
ch
:=
make
(
chan
any
)
go
func
()
{
go
func
()
{
defer
close
(
ch
)
defer
close
(
ch
)
if
err
:=
r
.
Completion
(
c
.
Request
.
Context
(),
llm
.
CompletionRequest
{
fn
:=
func
(
r
llm
.
CompletionResponse
)
{
Prompt
:
prompt
,
resp
:=
api
.
ChatResponse
{
Images
:
images
,
Format
:
req
.
Format
,
Options
:
opts
,
},
func
(
r
llm
.
CompletionResponse
)
{
ch
<-
api
.
ChatResponse
{
Model
:
req
.
Model
,
Model
:
req
.
Model
,
CreatedAt
:
time
.
Now
()
.
UTC
(),
CreatedAt
:
time
.
Now
()
.
UTC
(),
Message
:
api
.
Message
{
Role
:
"assistant"
,
Content
:
r
.
Content
},
Message
:
api
.
Message
{
Role
:
"assistant"
,
Content
:
r
.
Content
},
...
@@ -1352,64 +1182,52 @@ func (s *Server) ChatHandler(c *gin.Context) {
...
@@ -1352,64 +1182,52 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration
:
r
.
EvalDuration
,
EvalDuration
:
r
.
EvalDuration
,
},
},
}
}
});
err
!=
nil
{
if
r
.
Done
{
resp
.
TotalDuration
=
time
.
Since
(
checkpointStart
)
resp
.
LoadDuration
=
checkpointLoaded
.
Sub
(
checkpointStart
)
}
ch
<-
resp
}
if
err
:=
runner
.
llama
.
Completion
(
c
.
Request
.
Context
(),
llm
.
CompletionRequest
{
Prompt
:
prompt
,
Format
:
req
.
Format
,
Images
:
images
,
Options
:
opts
,
},
fn
);
err
!=
nil
{
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
}
}
}()
}()
if
req
.
Stream
!=
nil
&&
!*
req
.
Stream
{
if
req
.
Stream
!=
nil
&&
!*
req
.
Stream
{
// Accumulate responses into the final response
var
r
api
.
ChatResponse
var
final
api
.
ChatResponse
var
sb
strings
.
Builder
var
sb
strings
.
Builder
for
r
esp
:=
range
ch
{
for
r
r
:=
range
ch
{
switch
r
:=
r
esp
.
(
type
)
{
switch
t
:=
r
r
.
(
type
)
{
case
api
.
ChatResponse
:
case
api
.
ChatResponse
:
sb
.
WriteString
(
r
.
Message
.
Content
)
sb
.
WriteString
(
t
.
Message
.
Content
)
final
=
r
r
=
t
case
gin
.
H
:
case
gin
.
H
:
if
errorMsg
,
ok
:=
r
[
"error"
]
.
(
string
);
ok
{
msg
,
ok
:=
t
[
"error"
]
.
(
string
)
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
errorMsg
})
if
!
ok
{
return
msg
=
"unexpected error format in response"
}
else
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
"unexpected error format in response"
})
return
}
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
msg
})
return
default
:
default
:
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
"unexpected
error
"
})
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
"unexpected
response
"
})
return
return
}
}
}
}
final
.
Message
=
api
.
Message
{
Role
:
"assistant"
,
Content
:
sb
.
String
()
}
r
.
Message
.
Content
=
sb
.
String
()
c
.
JSON
(
http
.
StatusOK
,
final
)
c
.
JSON
(
http
.
StatusOK
,
r
)
return
return
}
}
streamResponse
(
c
,
ch
)
streamResponse
(
c
,
ch
)
}
}
func
handleErrorResponse
(
c
*
gin
.
Context
,
err
error
)
{
func
handleScheduleError
(
c
*
gin
.
Context
,
name
string
,
err
error
)
{
if
errors
.
Is
(
err
,
context
.
Canceled
)
{
switch
{
case
errors
.
Is
(
err
,
errRequired
)
:
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
case
errors
.
Is
(
err
,
context
.
Canceled
)
:
c
.
JSON
(
499
,
gin
.
H
{
"error"
:
"request canceled"
})
c
.
JSON
(
499
,
gin
.
H
{
"error"
:
"request canceled"
})
return
case
errors
.
Is
(
err
,
ErrMaxQueue
)
:
}
if
errors
.
Is
(
err
,
ErrMaxQueue
)
{
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
case
errors
.
Is
(
err
,
os
.
ErrNotExist
)
:
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model %q not found, try pulling it first"
,
name
)})
default
:
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
}
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
}
}
template/template.go
View file @
9bbddc37
...
@@ -5,6 +5,7 @@ import (
...
@@ -5,6 +5,7 @@ import (
"embed"
"embed"
"encoding/json"
"encoding/json"
"errors"
"errors"
"fmt"
"io"
"io"
"math"
"math"
"slices"
"slices"
...
@@ -14,6 +15,7 @@ import (
...
@@ -14,6 +15,7 @@ import (
"text/template/parse"
"text/template/parse"
"github.com/agnivade/levenshtein"
"github.com/agnivade/levenshtein"
"github.com/ollama/ollama/api"
"golang.org/x/exp/maps"
"golang.org/x/exp/maps"
)
)
...
@@ -74,30 +76,59 @@ func Named(s string) (*named, error) {
...
@@ -74,30 +76,59 @@ func Named(s string) (*named, error) {
return
nil
,
errors
.
New
(
"no matching template found"
)
return
nil
,
errors
.
New
(
"no matching template found"
)
}
}
var
DefaultTemplate
,
_
=
Parse
(
"{{ .Prompt }}"
)
type
Template
struct
{
type
Template
struct
{
*
template
.
Template
*
template
.
Template
raw
string
raw
string
}
}
func
(
t
*
Template
)
String
()
string
{
// response is a template node that can be added to templates that don't already have one
return
t
.
raw
var
response
=
parse
.
ActionNode
{
NodeType
:
parse
.
NodeAction
,
Pipe
:
&
parse
.
PipeNode
{
NodeType
:
parse
.
NodePipe
,
Cmds
:
[]
*
parse
.
CommandNode
{
{
NodeType
:
parse
.
NodeCommand
,
Args
:
[]
parse
.
Node
{
&
parse
.
FieldNode
{
NodeType
:
parse
.
NodeField
,
Ident
:
[]
string
{
"Response"
},
},
},
},
},
},
}
}
var
DefaultTemplate
,
_
=
Parse
(
"{{ .Prompt }}"
)
func
Parse
(
s
string
)
(
*
Template
,
error
)
{
func
Parse
(
s
string
)
(
*
Template
,
error
)
{
t
,
err
:=
template
.
New
(
""
)
.
Option
(
"missingkey=zero"
)
.
Parse
(
s
)
tmpl
:=
template
.
New
(
""
)
.
Option
(
"missingkey=zero"
)
tmpl
,
err
:=
tmpl
.
Parse
(
s
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
return
&
Template
{
Template
:
t
,
raw
:
s
},
nil
t
:=
Template
{
Template
:
tmpl
,
raw
:
s
}
if
vars
:=
t
.
Vars
();
!
slices
.
Contains
(
vars
,
"messages"
)
&&
!
slices
.
Contains
(
vars
,
"response"
)
{
// touch up the template and append {{ .Response }}
tmpl
.
Tree
.
Root
.
Nodes
=
append
(
tmpl
.
Tree
.
Root
.
Nodes
,
&
response
)
}
return
&
t
,
nil
}
func
(
t
*
Template
)
String
()
string
{
return
t
.
raw
}
}
func
(
t
*
Template
)
Vars
()
[]
string
{
func
(
t
*
Template
)
Vars
()
[]
string
{
var
vars
[]
string
var
vars
[]
string
for
_
,
n
:=
range
t
.
Tree
.
Root
.
Nodes
{
for
_
,
tt
:=
range
t
.
Templates
()
{
vars
=
append
(
vars
,
parseNode
(
n
)
...
)
for
_
,
n
:=
range
tt
.
Root
.
Nodes
{
vars
=
append
(
vars
,
parseNode
(
n
)
...
)
}
}
}
set
:=
make
(
map
[
string
]
struct
{})
set
:=
make
(
map
[
string
]
struct
{})
...
@@ -110,6 +141,103 @@ func (t *Template) Vars() []string {
...
@@ -110,6 +141,103 @@ func (t *Template) Vars() []string {
return
vars
return
vars
}
}
type
Values
struct
{
Messages
[]
api
.
Message
}
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
system
,
collated
:=
collate
(
v
.
Messages
)
if
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
"System"
:
system
,
"Messages"
:
collated
,
})
}
var
b
bytes
.
Buffer
var
prompt
,
response
string
for
i
,
m
:=
range
collated
{
if
m
.
Role
==
"user"
{
prompt
=
m
.
Content
}
else
{
response
=
m
.
Content
}
if
i
!=
len
(
collated
)
-
1
&&
prompt
!=
""
&&
response
!=
""
{
if
err
:=
t
.
Template
.
Execute
(
&
b
,
map
[
string
]
any
{
"System"
:
""
,
"Prompt"
:
prompt
,
"Response"
:
response
,
});
err
!=
nil
{
return
err
}
prompt
=
""
response
=
""
}
}
var
cut
bool
tree
:=
t
.
Template
.
Copy
()
// for the last message, cut everything after "{{ .Response }}"
tree
.
Root
.
Nodes
=
slices
.
DeleteFunc
(
tree
.
Root
.
Nodes
,
func
(
n
parse
.
Node
)
bool
{
if
slices
.
Contains
(
parseNode
(
n
),
"Response"
)
{
cut
=
true
}
return
cut
})
if
err
:=
template
.
Must
(
template
.
New
(
""
)
.
AddParseTree
(
""
,
tree
))
.
Execute
(
&
b
,
map
[
string
]
any
{
"System"
:
system
,
"Prompt"
:
prompt
,
});
err
!=
nil
{
return
err
}
_
,
err
:=
io
.
Copy
(
w
,
&
b
)
return
err
}
type
messages
[]
*
api
.
Message
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func
collate
(
msgs
[]
api
.
Message
)
(
system
string
,
collated
messages
)
{
var
n
int
for
i
:=
range
msgs
{
msg
:=
msgs
[
i
]
if
msg
.
Role
==
"system"
{
if
system
!=
""
{
system
+=
"
\n\n
"
}
system
+=
msg
.
Content
continue
}
for
range
msg
.
Images
{
imageTag
:=
fmt
.
Sprintf
(
"[img-%d]"
,
n
)
if
!
strings
.
Contains
(
msg
.
Content
,
"[img]"
)
{
msg
.
Content
=
strings
.
TrimSpace
(
"[img] "
+
msg
.
Content
)
}
msg
.
Content
=
strings
.
Replace
(
msg
.
Content
,
"[img]"
,
imageTag
,
1
)
n
++
}
if
len
(
collated
)
>
0
&&
collated
[
len
(
collated
)
-
1
]
.
Role
==
msg
.
Role
{
collated
[
len
(
collated
)
-
1
]
.
Content
+=
"
\n\n
"
+
msg
.
Content
}
else
{
collated
=
append
(
collated
,
&
msg
)
}
}
return
}
func
parseNode
(
n
parse
.
Node
)
[]
string
{
func
parseNode
(
n
parse
.
Node
)
[]
string
{
switch
n
:=
n
.
(
type
)
{
switch
n
:=
n
.
(
type
)
{
case
*
parse
.
ActionNode
:
case
*
parse
.
ActionNode
:
...
@@ -152,6 +280,8 @@ func parseNode(n parse.Node) []string {
...
@@ -152,6 +280,8 @@ func parseNode(n parse.Node) []string {
return
names
return
names
case
*
parse
.
FieldNode
:
case
*
parse
.
FieldNode
:
return
n
.
Ident
return
n
.
Ident
case
*
parse
.
TemplateNode
:
return
parseNode
(
n
.
Pipe
)
}
}
return
nil
return
nil
...
...
template/template_test.go
View file @
9bbddc37
...
@@ -11,6 +11,7 @@ import (
...
@@ -11,6 +11,7 @@ import (
"testing"
"testing"
"text/template"
"text/template"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/llm"
)
)
...
@@ -64,13 +65,12 @@ func TestParse(t *testing.T) {
...
@@ -64,13 +65,12 @@ func TestParse(t *testing.T) {
template
string
template
string
vars
[]
string
vars
[]
string
}{
}{
{
"{{ .Prompt }}"
,
[]
string
{
"prompt"
}},
{
"{{ .Prompt }}"
,
[]
string
{
"prompt"
,
"response"
}},
{
"{{ .System }} {{ .Prompt }}"
,
[]
string
{
"prompt"
,
"system"
}},
{
"{{ .System }} {{ .Prompt }}"
,
[]
string
{
"prompt"
,
"response"
,
"system"
}},
{
"{{ .System }} {{ .Prompt }} {{ .Response }}"
,
[]
string
{
"prompt"
,
"response"
,
"system"
}},
{
"{{ .System }} {{ .Prompt }} {{ .Response }}"
,
[]
string
{
"prompt"
,
"response"
,
"system"
}},
{
"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}"
,
[]
string
{
"prompt"
,
"system"
,
"tools"
}},
{
"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}"
,
[]
string
{
"prompt"
,
"response"
,
"system"
,
"tools"
}},
{
"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ if eq .Role
\"
system
\"
}}SYSTEM: {{ .Content }}{{ else if eq .Role
\"
user
\"
}}USER: {{ .Content }}{{ else if eq .Role
\"
assistant
\"
}}ASSISTANT: {{ .Content }}{{ end }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ if eq .Role
\"
system
\"
}}SYSTEM: {{ .Content }}{{ else if eq .Role
\"
user
\"
}}USER: {{ .Content }}{{ else if eq .Role
\"
assistant
\"
}}ASSISTANT: {{ .Content }}{{ end }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ .Prompt }} {{ .Suffix }}"
,
[]
string
{
"prompt"
,
"suffix"
}},
}
}
for
_
,
tt
:=
range
cases
{
for
_
,
tt
:=
range
cases
{
...
@@ -87,3 +87,159 @@ func TestParse(t *testing.T) {
...
@@ -87,3 +87,159 @@ func TestParse(t *testing.T) {
})
})
}
}
}
}
// TestExecuteWithMessages is a table-driven test for Template.Execute.
// Each case carries several template variants, one Values input and one
// expected string. Every variant is parsed with Parse, executed with the case's
// values, and its output compared against that same expected string.
func
TestExecuteWithMessages
(
t
*
testing
.
T
)
{
// template pairs a sub-test name with the template source to parse.
type
template
struct
{
name
string
template
string
}
cases
:=
[]
struct
{
name
string
templates
[]
template
values
Values
expected
string
}{
{
"mistral"
,
[]
template
{
{
"no response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
},
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`
},
},
Values
{
Messages
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"Hello friend!"
},
{
Role
:
"assistant"
,
Content
:
"Hello human!"
},
{
Role
:
"user"
,
Content
:
"What is your name?"
},
},
},
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `
,
},
{
"mistral system"
,
[]
template
{
{
"no response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
},
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`
},
},
Values
{
Messages
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a helpful assistant!"
},
{
Role
:
"user"
,
Content
:
"Hello friend!"
},
{
Role
:
"assistant"
,
Content
:
"Hello human!"
},
{
Role
:
"user"
,
Content
:
"What is your name?"
},
},
},
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
What is your name?[/INST] `
,
},
{
"chatml"
,
[]
template
{
// this does not have a "no response" test because it's impossible to render the same output
{
"response"
,
`{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
`
},
{
"messages"
,
`
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>assistant
`
},
},
Values
{
Messages
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a helpful assistant!"
},
{
Role
:
"user"
,
Content
:
"Hello friend!"
},
{
Role
:
"assistant"
,
Content
:
"Hello human!"
},
{
Role
:
"user"
,
Content
:
"What is your name?"
},
},
},
`<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
`
,
},
{
"moondream"
,
[]
template
{
// this does not have a "no response" test because it's impossible to render the same output
{
"response"
,
`{{ if .Prompt }}Question: {{ .Prompt }}
{{ end }}Answer: {{ .Response }}
`
},
{
"messages"
,
`
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
{{- end }}
{{- end }}Answer: `
},
},
Values
{
Messages
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"What's in this image?"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
""
)}},
{
Role
:
"assistant"
,
Content
:
"It's a hot dog."
},
{
Role
:
"user"
,
Content
:
"What's in _this_ image?"
},
{
Role
:
"user"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
""
)}},
{
Role
:
"user"
,
Content
:
"Is it a hot dog?"
},
},
},
`Question: [img-0] What's in this image?
Answer: It's a hot dog.
Question: What's in _this_ image?
[img-1]
Is it a hot dog?
Answer: `
,
},
}
// One sub-test per case, and inside it one sub-test per template variant.
for
_
,
tt
:=
range
cases
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
for
_
,
ttt
:=
range
tt
.
templates
{
t
.
Run
(
ttt
.
name
,
func
(
t
*
testing
.
T
)
{
tmpl
,
err
:=
Parse
(
ttt
.
template
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
var
b
bytes
.
Buffer
if
err
:=
tmpl
.
Execute
(
&
b
,
tt
.
values
);
err
!=
nil
{
t
.
Fatal
(
err
)
}
// Every variant of a case must render exactly tt.expected.
if
b
.
String
()
!=
tt
.
expected
{
t
.
Errorf
(
"expected
\n
%s,
\n
got
\n
%s"
,
tt
.
expected
,
b
.
String
())
}
})
}
})
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment