Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
2c3fe1fd
You need to sign in or sign up before continuing.
Commit
2c3fe1fd
authored
Jun 20, 2024
by
Michael Yang
Browse files
comments
parent
269ed6e6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
223 additions
and
112 deletions
+223
-112
server/prompt.go
server/prompt.go
+19
-10
server/prompt_test.go
server/prompt_test.go
+12
-22
server/routes.go
server/routes.go
+25
-21
template/template.go
template/template.go
+25
-21
template/template_test.go
template/template_test.go
+142
-38
No files found.
server/prompt.go
View file @
2c3fe1fd
...
@@ -11,8 +11,13 @@ import (
...
@@ -11,8 +11,13 @@ import (
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/template"
)
)
func
chatPrompt
(
ctx
context
.
Context
,
r
*
runnerRef
,
msgs
[]
api
.
Message
)
(
prompt
string
,
images
[]
llm
.
ImageData
,
_
error
)
{
type
tokenizeFunc
func
(
context
.
Context
,
string
)
([]
int
,
error
)
// extract system messages which should always be included
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func
chatPrompt
(
ctx
context
.
Context
,
m
*
Model
,
tokenize
tokenizeFunc
,
opts
*
api
.
Options
,
msgs
[]
api
.
Message
)
(
prompt
string
,
images
[]
llm
.
ImageData
,
_
error
)
{
// pull out any system messages which should always be included in the prompt
var
system
[]
api
.
Message
var
system
[]
api
.
Message
msgs
=
slices
.
DeleteFunc
(
msgs
,
func
(
m
api
.
Message
)
bool
{
msgs
=
slices
.
DeleteFunc
(
msgs
,
func
(
m
api
.
Message
)
bool
{
if
m
.
Role
==
"system"
{
if
m
.
Role
==
"system"
{
...
@@ -23,32 +28,35 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
...
@@ -23,32 +28,35 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
return
false
return
false
})
})
if
len
(
system
)
==
0
&&
r
.
model
.
System
!=
""
{
if
len
(
system
)
==
0
&&
m
.
System
!=
""
{
// add model system prompt since it wasn't provided
// add model system prompt since it wasn't provided
system
=
append
(
system
,
api
.
Message
{
Role
:
"system"
,
Content
:
r
.
model
.
System
})
system
=
append
(
system
,
api
.
Message
{
Role
:
"system"
,
Content
:
m
.
System
})
}
}
// always include the last message
n
:=
len
(
msgs
)
-
1
n
:=
len
(
msgs
)
-
1
// in reverse, find all messages that fit into context window
for
i
:=
n
-
1
;
i
>=
0
;
i
--
{
for
i
:=
n
-
1
;
i
>=
0
;
i
--
{
var
b
bytes
.
Buffer
var
b
bytes
.
Buffer
if
err
:=
r
.
model
.
Template
.
Execute
(
&
b
,
template
.
Values
{
Messages
:
append
(
system
,
msgs
[
i
:
]
...
)});
err
!=
nil
{
if
err
:=
m
.
Template
.
Execute
(
&
b
,
template
.
Values
{
Messages
:
append
(
system
,
msgs
[
i
:
]
...
)});
err
!=
nil
{
return
""
,
nil
,
err
return
""
,
nil
,
err
}
}
s
,
err
:=
r
.
llama
.
T
okenize
(
ctx
,
b
.
String
())
s
,
err
:=
t
okenize
(
ctx
,
b
.
String
())
if
err
!=
nil
{
if
err
!=
nil
{
return
""
,
nil
,
err
return
""
,
nil
,
err
}
}
c
:=
len
(
s
)
c
:=
len
(
s
)
if
r
.
model
.
ProjectorPaths
!=
nil
{
if
m
.
ProjectorPaths
!=
nil
{
for
_
,
m
:=
range
msgs
[
i
:
]
{
for
_
,
m
:=
range
msgs
[
i
:
]
{
// TODO: get image embedding length from project metadata
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
c
+=
768
*
len
(
m
.
Images
)
c
+=
768
*
len
(
m
.
Images
)
}
}
}
}
if
c
>
r
.
NumCtx
{
if
c
>
opts
.
NumCtx
{
slog
.
Debug
(
"truncating input messages which exceed context length"
,
"truncated"
,
len
(
msgs
[
i
:
]))
slog
.
Debug
(
"truncating input messages which exceed context length"
,
"truncated"
,
len
(
msgs
[
i
:
]))
break
break
}
else
{
}
else
{
...
@@ -56,8 +64,9 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
...
@@ -56,8 +64,9 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
}
}
}
}
// truncate any messages that do not fit into the context window
var
b
bytes
.
Buffer
var
b
bytes
.
Buffer
if
err
:=
r
.
model
.
Template
.
Execute
(
&
b
,
template
.
Values
{
Messages
:
append
(
system
,
msgs
[
n
:
]
...
)});
err
!=
nil
{
if
err
:=
m
.
Template
.
Execute
(
&
b
,
template
.
Values
{
Messages
:
append
(
system
,
msgs
[
n
:
]
...
)});
err
!=
nil
{
return
""
,
nil
,
err
return
""
,
nil
,
err
}
}
...
...
server/prompt_test.go
View file @
2c3fe1fd
...
@@ -7,15 +7,10 @@ import (
...
@@ -7,15 +7,10 @@ import (
"testing"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/template"
)
)
type
mock
struct
{
func
tokenize
(
_
context
.
Context
,
s
string
)
(
tokens
[]
int
,
err
error
)
{
llm
.
LlamaServer
}
func
(
m
mock
)
Tokenize
(
_
context
.
Context
,
s
string
)
(
tokens
[]
int
,
err
error
)
{
for
range
strings
.
Fields
(
s
)
{
for
range
strings
.
Fields
(
s
)
{
tokens
=
append
(
tokens
,
len
(
tokens
))
tokens
=
append
(
tokens
,
len
(
tokens
))
}
}
...
@@ -48,7 +43,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -48,7 +43,7 @@ func TestChatPrompt(t *testing.T) {
},
},
},
},
{
{
name
:
"truncate messages"
,
name
:
"truncate messages"
,
limit
:
1
,
limit
:
1
,
msgs
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
...
@@ -60,7 +55,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -60,7 +55,7 @@ func TestChatPrompt(t *testing.T) {
},
},
},
},
{
{
name
:
"truncate messages with image"
,
name
:
"truncate messages with image"
,
limit
:
64
,
limit
:
64
,
msgs
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
...
@@ -75,7 +70,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -75,7 +70,7 @@ func TestChatPrompt(t *testing.T) {
},
},
},
},
{
{
name
:
"truncate messages with images"
,
name
:
"truncate messages with images"
,
limit
:
64
,
limit
:
64
,
msgs
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
...
@@ -90,7 +85,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -90,7 +85,7 @@ func TestChatPrompt(t *testing.T) {
},
},
},
},
{
{
name
:
"messages with images"
,
name
:
"messages with images"
,
limit
:
2048
,
limit
:
2048
,
msgs
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
...
@@ -106,7 +101,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -106,7 +101,7 @@ func TestChatPrompt(t *testing.T) {
},
},
},
},
{
{
name
:
"message with image tag"
,
name
:
"message with image tag"
,
limit
:
2048
,
limit
:
2048
,
msgs
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"You're a test, Harry! [img]"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
{
Role
:
"user"
,
Content
:
"You're a test, Harry! [img]"
,
Images
:
[]
api
.
ImageData
{[]
byte
(
"something"
)}},
...
@@ -122,7 +117,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -122,7 +117,7 @@ func TestChatPrompt(t *testing.T) {
},
},
},
},
{
{
name
:
"messages with interleaved images"
,
name
:
"messages with interleaved images"
,
limit
:
2048
,
limit
:
2048
,
msgs
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
...
@@ -140,7 +135,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -140,7 +135,7 @@ func TestChatPrompt(t *testing.T) {
},
},
},
},
{
{
name
:
"truncate message with interleaved images"
,
name
:
"truncate message with interleaved images"
,
limit
:
1024
,
limit
:
1024
,
msgs
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
{
Role
:
"user"
,
Content
:
"You're a test, Harry!"
},
...
@@ -157,7 +152,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -157,7 +152,7 @@ func TestChatPrompt(t *testing.T) {
},
},
},
},
{
{
name
:
"message with system prompt"
,
name
:
"message with system prompt"
,
limit
:
2048
,
limit
:
2048
,
msgs
:
[]
api
.
Message
{
msgs
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are the Test Who Lived."
},
{
Role
:
"system"
,
Content
:
"You are the Test Who Lived."
},
...
@@ -181,14 +176,9 @@ func TestChatPrompt(t *testing.T) {
...
@@ -181,14 +176,9 @@ func TestChatPrompt(t *testing.T) {
for
_
,
tt
:=
range
cases
{
for
_
,
tt
:=
range
cases
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
r
:=
runnerRef
{
model
:=
Model
{
Template
:
tmpl
,
ProjectorPaths
:
[]
string
{
"vision"
}}
llama
:
mock
{},
opts
:=
api
.
Options
{
Runner
:
api
.
Runner
{
NumCtx
:
tt
.
limit
}}
model
:
&
Model
{
Template
:
tmpl
,
ProjectorPaths
:
[]
string
{
"vision"
}},
prompt
,
images
,
err
:=
chatPrompt
(
context
.
TODO
(),
&
model
,
tokenize
,
&
opts
,
tt
.
msgs
)
Options
:
&
api
.
Options
{},
}
r
.
NumCtx
=
tt
.
limit
prompt
,
images
,
err
:=
chatPrompt
(
context
.
TODO
(),
&
r
,
tt
.
msgs
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
}
}
...
...
server/routes.go
View file @
2c3fe1fd
...
@@ -54,6 +54,8 @@ func init() {
...
@@ -54,6 +54,8 @@ func init() {
gin
.
SetMode
(
mode
)
gin
.
SetMode
(
mode
)
}
}
var
errRequired
=
errors
.
New
(
"is required"
)
func
modelOptions
(
model
*
Model
,
requestOpts
map
[
string
]
interface
{})
(
api
.
Options
,
error
)
{
func
modelOptions
(
model
*
Model
,
requestOpts
map
[
string
]
interface
{})
(
api
.
Options
,
error
)
{
opts
:=
api
.
DefaultOptions
()
opts
:=
api
.
DefaultOptions
()
if
err
:=
opts
.
FromMap
(
model
.
Options
);
err
!=
nil
{
if
err
:=
opts
.
FromMap
(
model
.
Options
);
err
!=
nil
{
...
@@ -69,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
...
@@ -69,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
func
(
s
*
Server
)
scheduleRunner
(
ctx
context
.
Context
,
name
string
,
caps
[]
Capability
,
requestOpts
map
[
string
]
any
,
keepAlive
*
api
.
Duration
)
(
*
runnerRef
,
error
)
{
func
(
s
*
Server
)
scheduleRunner
(
ctx
context
.
Context
,
name
string
,
caps
[]
Capability
,
requestOpts
map
[
string
]
any
,
keepAlive
*
api
.
Duration
)
(
*
runnerRef
,
error
)
{
if
name
==
""
{
if
name
==
""
{
return
nil
,
errors
.
New
(
"model is r
equired
"
)
return
nil
,
fmt
.
Errorf
(
"model %w"
,
errR
equired
)
}
}
model
,
err
:=
GetModel
(
name
)
model
,
err
:=
GetModel
(
name
)
...
@@ -121,7 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
...
@@ -121,7 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support generate"
,
req
.
Model
)})
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support generate"
,
req
.
Model
)})
return
return
}
else
if
err
!=
nil
{
}
else
if
err
!=
nil
{
handleScheduleError
(
c
,
err
)
handleScheduleError
(
c
,
req
.
Model
,
err
)
return
}
if
req
.
Prompt
==
""
{
c
.
JSON
(
http
.
StatusOK
,
api
.
GenerateResponse
{
Model
:
req
.
Model
,
CreatedAt
:
time
.
Now
()
.
UTC
(),
Done
:
true
,
DoneReason
:
"load"
,
})
return
return
}
}
...
@@ -139,23 +151,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
...
@@ -139,23 +151,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"system"
,
Content
:
r
.
model
.
System
})
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"system"
,
Content
:
r
.
model
.
System
})
}
}
if
req
.
Prompt
!=
""
{
for
_
,
i
:=
range
images
{
for
_
,
i
:=
range
images
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
fmt
.
Sprintf
(
"[img-%d]"
,
i
.
ID
)})
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
fmt
.
Sprintf
(
"[img-%d]"
,
i
.
ID
)})
}
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
req
.
Prompt
})
}
}
if
len
(
msgs
)
==
0
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
req
.
Prompt
})
c
.
JSON
(
http
.
StatusOK
,
api
.
GenerateResponse
{
Model
:
req
.
Model
,
CreatedAt
:
time
.
Now
()
.
UTC
(),
Done
:
true
,
DoneReason
:
"load"
,
})
return
}
tmpl
:=
r
.
model
.
Template
tmpl
:=
r
.
model
.
Template
if
req
.
Template
!=
""
{
if
req
.
Template
!=
""
{
...
@@ -256,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
...
@@ -256,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
r
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
[]
Capability
{},
req
.
Options
,
req
.
KeepAlive
)
r
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
[]
Capability
{},
req
.
Options
,
req
.
KeepAlive
)
if
err
!=
nil
{
if
err
!=
nil
{
handleScheduleError
(
c
,
err
)
handleScheduleError
(
c
,
req
.
Model
,
err
)
return
return
}
}
...
@@ -1135,7 +1135,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
...
@@ -1135,7 +1135,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support chat"
,
req
.
Model
)})
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support chat"
,
req
.
Model
)})
return
return
}
else
if
err
!=
nil
{
}
else
if
err
!=
nil
{
handleScheduleError
(
c
,
err
)
handleScheduleError
(
c
,
req
.
Model
,
err
)
return
return
}
}
...
@@ -1150,7 +1150,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
...
@@ -1150,7 +1150,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
return
}
}
prompt
,
images
,
err
:=
chatPrompt
(
c
.
Request
.
Context
(),
r
,
req
.
Messages
)
prompt
,
images
,
err
:=
chatPrompt
(
c
.
Request
.
Context
(),
r
.
model
,
r
.
llama
.
Tokenize
,
r
.
Options
,
req
.
Messages
)
if
err
!=
nil
{
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
return
...
@@ -1215,12 +1215,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
...
@@ -1215,12 +1215,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
streamResponse
(
c
,
ch
)
streamResponse
(
c
,
ch
)
}
}
func
handleScheduleError
(
c
*
gin
.
Context
,
err
error
)
{
func
handleScheduleError
(
c
*
gin
.
Context
,
name
string
,
err
error
)
{
switch
{
switch
{
case
errors
.
Is
(
err
,
errRequired
)
:
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
case
errors
.
Is
(
err
,
context
.
Canceled
)
:
case
errors
.
Is
(
err
,
context
.
Canceled
)
:
c
.
JSON
(
499
,
gin
.
H
{
"error"
:
"request canceled"
})
c
.
JSON
(
499
,
gin
.
H
{
"error"
:
"request canceled"
})
case
errors
.
Is
(
err
,
ErrMaxQueue
)
:
case
errors
.
Is
(
err
,
ErrMaxQueue
)
:
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
err
.
Error
()})
case
errors
.
Is
(
err
,
os
.
ErrNotExist
)
:
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model %q not found, try pulling it first"
,
name
)})
default
:
default
:
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
}
}
...
...
template/template.go
View file @
2c3fe1fd
...
@@ -83,6 +83,7 @@ type Template struct {
...
@@ -83,6 +83,7 @@ type Template struct {
raw
string
raw
string
}
}
// response is a template node that can be added to templates that don't already have one
var
response
=
parse
.
ActionNode
{
var
response
=
parse
.
ActionNode
{
NodeType
:
parse
.
NodeAction
,
NodeType
:
parse
.
NodeAction
,
Pipe
:
&
parse
.
PipeNode
{
Pipe
:
&
parse
.
PipeNode
{
...
@@ -101,28 +102,25 @@ var response = parse.ActionNode{
...
@@ -101,28 +102,25 @@ var response = parse.ActionNode{
},
},
}
}
func
Parse
(
s
string
)
(
*
Template
,
error
)
{
var
funcs
=
template
.
FuncMap
{
tmpl
:=
template
.
New
(
""
)
.
Option
(
"missingkey=zero"
)
.
Funcs
(
template
.
FuncMap
{
"toJson"
:
func
(
v
any
)
string
{
"toJson"
:
func
(
v
any
)
string
{
b
,
err
:=
json
.
Marshal
(
v
)
b
,
err
:=
json
.
Marshal
(
v
)
if
err
!=
nil
{
if
err
!=
nil
{
return
""
return
""
}
}
return
string
(
b
)
},
"isLastMessage"
:
func
(
s
[]
*
api
.
Message
,
m
*
api
.
Message
)
bool
{
for
i
:=
len
(
s
)
-
1
;
i
>=
0
;
i
--
{
if
m
.
Role
!=
s
[
i
]
.
Role
{
continue
}
return
m
==
s
[
i
]
return
string
(
b
)
}
},
"add"
:
func
(
a
,
b
int
)
int
{
return
a
+
b
},
"sub"
:
func
(
a
,
b
int
)
int
{
return
a
-
b
},
}
return
false
func
Parse
(
s
string
)
(
*
Template
,
error
)
{
},
tmpl
:=
template
.
New
(
""
)
.
Option
(
"missingkey=zero"
)
.
Funcs
(
funcs
)
})
tmpl
,
err
:=
tmpl
.
Parse
(
s
)
tmpl
,
err
:=
tmpl
.
Parse
(
s
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -218,7 +216,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
...
@@ -218,7 +216,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return
err
return
err
}
}
func
collate
(
msgs
[]
api
.
Message
)
(
system
string
,
collated
[]
*
api
.
Message
)
{
type
messages
[]
*
api
.
Message
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func
collate
(
msgs
[]
api
.
Message
)
(
system
string
,
collated
messages
)
{
var
n
int
var
n
int
for
i
:=
range
msgs
{
for
i
:=
range
msgs
{
msg
:=
msgs
[
i
]
msg
:=
msgs
[
i
]
...
...
template/template_test.go
View file @
2c3fe1fd
...
@@ -8,6 +8,7 @@ import (
...
@@ -8,6 +8,7 @@ import (
"os"
"os"
"path/filepath"
"path/filepath"
"slices"
"slices"
"strconv"
"testing"
"testing"
"text/template"
"text/template"
...
@@ -15,6 +16,98 @@ import (
...
@@ -15,6 +16,98 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/llm"
)
)
func
TestFuncs
(
t
*
testing
.
T
)
{
t
.
Run
(
"toJson"
,
func
(
t
*
testing
.
T
)
{
cases
:=
[]
struct
{
input
any
expected
string
}{
{
nil
,
"null"
},
{
true
,
"true"
},
{
false
,
"false"
},
{
0
,
"0"
},
{
1
,
"1"
},
{
1.0
,
"1"
},
{
1.1
,
"1.1"
},
{
""
,
`""`
},
{
"hello"
,
`"hello"`
},
{[]
int
{
1
,
2
,
3
},
"[1,2,3]"
},
{[]
string
{
"a"
,
"b"
,
"c"
},
`["a","b","c"]`
},
{
map
[
string
]
int
{
"a"
:
1
,
"b"
:
2
},
`{"a":1,"b":2}`
},
{
map
[
string
]
string
{
"a"
:
"b"
,
"c"
:
"d"
},
`{"a":"b","c":"d"}`
},
}
for
_
,
tt
:=
range
cases
{
t
.
Run
(
tt
.
expected
,
func
(
t
*
testing
.
T
)
{
toJson
,
ok
:=
funcs
[
"toJson"
]
.
(
func
(
any
)
string
)
if
!
ok
{
t
.
Fatal
(
"toJson is not a function"
)
}
if
s
:=
toJson
(
tt
.
input
);
s
!=
tt
.
expected
{
t
.
Errorf
(
"expected %q, got %q"
,
tt
.
expected
,
s
)
}
})
}
})
t
.
Run
(
"add"
,
func
(
t
*
testing
.
T
)
{
cases
:=
[]
struct
{
a
,
b
int
expected
int
}{
{
0
,
0
,
0
},
{
0
,
1
,
1
},
{
1
,
0
,
1
},
{
1
,
1
,
2
},
{
1
,
-
1
,
0
},
{
-
1
,
1
,
0
},
{
-
1
,
-
1
,
-
2
},
}
for
_
,
tt
:=
range
cases
{
t
.
Run
(
strconv
.
Itoa
(
tt
.
expected
),
func
(
t
*
testing
.
T
)
{
add
,
ok
:=
funcs
[
"add"
]
.
(
func
(
int
,
int
)
int
)
if
!
ok
{
t
.
Fatal
(
"add is not a function"
)
}
if
n
:=
add
(
tt
.
a
,
tt
.
b
);
n
!=
tt
.
expected
{
t
.
Errorf
(
"expected %d, got %d"
,
tt
.
expected
,
n
)
}
})
}
})
t
.
Run
(
"sub"
,
func
(
t
*
testing
.
T
)
{
cases
:=
[]
struct
{
a
,
b
int
expected
int
}{
{
0
,
0
,
0
},
{
0
,
1
,
-
1
},
{
1
,
0
,
1
},
{
1
,
1
,
0
},
{
1
,
-
1
,
2
},
{
-
1
,
1
,
-
2
},
{
-
1
,
-
1
,
0
},
}
for
_
,
tt
:=
range
cases
{
t
.
Run
(
strconv
.
Itoa
(
tt
.
expected
),
func
(
t
*
testing
.
T
)
{
sub
,
ok
:=
funcs
[
"sub"
]
.
(
func
(
int
,
int
)
int
)
if
!
ok
{
t
.
Fatal
(
"sub is not a function"
)
}
if
n
:=
sub
(
tt
.
a
,
tt
.
b
);
n
!=
tt
.
expected
{
t
.
Errorf
(
"expected %d, got %d"
,
tt
.
expected
,
n
)
}
})
}
})
}
func
TestNamed
(
t
*
testing
.
T
)
{
func
TestNamed
(
t
*
testing
.
T
)
{
f
,
err
:=
os
.
Open
(
filepath
.
Join
(
"testdata"
,
"templates.jsonl"
))
f
,
err
:=
os
.
Open
(
filepath
.
Join
(
"testdata"
,
"templates.jsonl"
))
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -89,77 +182,86 @@ func TestParse(t *testing.T) {
...
@@ -89,77 +182,86 @@ func TestParse(t *testing.T) {
}
}
func
TestExecuteWithMessages
(
t
*
testing
.
T
)
{
func
TestExecuteWithMessages
(
t
*
testing
.
T
)
{
type
template
struct
{
name
string
template
string
}
cases
:=
[]
struct
{
cases
:=
[]
struct
{
templates
[]
string
name
string
templates
[]
template
values
Values
values
Values
expected
string
expected
string
}{
}{
{
{
[]
string
{
"mistral"
,
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
,
[]
template
{
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
,
{
"no response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
},
`{{- range .Messages }}
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
{
"messages"
,
`{{- range .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}
{{- end }}`
,
{{- end }}`
}
,
},
},
Values
{
Values
{
Messages
:
[]
api
.
Message
{
Messages
:
[]
api
.
Message
{
{
Role
:
"user"
,
Content
:
"Hello friend!"
},
{
Role
:
"user"
,
Content
:
"Hello friend!"
},
{
Role
:
"assistant"
,
Content
:
"Hello human!"
},
{
Role
:
"assistant"
,
Content
:
"Hello human!"
},
{
Role
:
"user"
,
Content
:
"
Yay!
"
},
{
Role
:
"user"
,
Content
:
"
What is your name?
"
},
},
},
},
},
`[INST] Hello friend![/INST] Hello human![INST]
Yay!
[/INST] `
,
`[INST] Hello friend![/INST] Hello human![INST]
What is your name?
[/INST] `
,
},
},
{
{
[]
string
{
"mistral system"
,
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
,
[]
template
{
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
,
{
"no response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
},
`
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`
{{- range .Messages }}
{{- range .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (
isLastMessage
$.Messages .) $.System }}{{ $.System }}{{
print
"\n\n" }}
{{- if eq .Role "user" }}[INST] {{ if and (
eq (index $.Messages (sub (len
$.Messages
) 1))
.) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}
{{- end }}`
,
{{- end }}`
}
,
},
},
Values
{
Values
{
Messages
:
[]
api
.
Message
{
Messages
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a helpful assistant!"
},
{
Role
:
"system"
,
Content
:
"You are a helpful assistant!"
},
{
Role
:
"user"
,
Content
:
"Hello friend!"
},
{
Role
:
"user"
,
Content
:
"Hello friend!"
},
{
Role
:
"assistant"
,
Content
:
"Hello human!"
},
{
Role
:
"assistant"
,
Content
:
"Hello human!"
},
{
Role
:
"user"
,
Content
:
"
Yay!
"
},
{
Role
:
"user"
,
Content
:
"
What is your name?
"
},
},
},
},
},
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
Yay!
[/INST] `
,
What is your name?
[/INST] `
,
},
},
{
{
[]
string
{
"chatml"
,
`{{ if .System }}<|im_start|>system
[]
template
{
// this does not have a "no response" test because it's impossible to render the same output
{
"response"
,
`{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{ .Response }}<|im_end|>
`
,
`
}
,
`
{
"messages"
,
`
{{- range .Messages }}
{{- range .Messages }}
{{- if and (eq .Role "user") (
isLastMessage
$.Messages .) $.System }}<|im_start|>system
{{- if and (eq .Role "user") (
eq (index $.Messages (sub (len
$.Messages
) 1))
.) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{
print
"\n" }}
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{
print
"\n" }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>assistant
{{- end }}<|im_start|>assistant
`
,
`
}
,
},
},
Values
{
Values
{
Messages
:
[]
api
.
Message
{
Messages
:
[]
api
.
Message
{
{
Role
:
"system"
,
Content
:
"You are a helpful assistant!"
},
{
Role
:
"system"
,
Content
:
"You are a helpful assistant!"
},
{
Role
:
"user"
,
Content
:
"Hello friend!"
},
{
Role
:
"user"
,
Content
:
"Hello friend!"
},
{
Role
:
"assistant"
,
Content
:
"Hello human!"
},
{
Role
:
"assistant"
,
Content
:
"Hello human!"
},
{
Role
:
"user"
,
Content
:
"
Yay!
"
},
{
Role
:
"user"
,
Content
:
"
What is your name?
"
},
},
},
},
},
`<|im_start|>user
`<|im_start|>user
...
@@ -169,23 +271,25 @@ Hello human!<|im_end|>
...
@@ -169,23 +271,25 @@ Hello human!<|im_end|>
<|im_start|>system
<|im_start|>system
You are a helpful assistant!<|im_end|>
You are a helpful assistant!<|im_end|>
<|im_start|>user
<|im_start|>user
Yay!
<|im_end|>
What is your name?
<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
`
,
`
,
},
},
{
{
[]
string
{
"moondream"
,
`{{ if .Prompt }}Question: {{ .Prompt }}
[]
template
{
// this does not have a "no response" test because it's impossible to render the same output
{
"response"
,
`{{ if .Prompt }}Question: {{ .Prompt }}
{{ end }}Answer: {{ .Response }}
{{ end }}Answer: {{ .Response }}
`
,
`
}
,
`
{
"messages"
,
`
{{- range .Messages }}
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{
print
"\n\n" }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{
print
"\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
{{- end }}
{{- end }}
{{- end }}Answer: `
,
{{- end }}Answer: `
}
,
},
},
Values
{
Values
{
Messages
:
[]
api
.
Message
{
Messages
:
[]
api
.
Message
{
...
@@ -211,10 +315,10 @@ Answer: `,
...
@@ -211,10 +315,10 @@ Answer: `,
}
}
for
_
,
tt
:=
range
cases
{
for
_
,
tt
:=
range
cases
{
t
.
Run
(
""
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
for
_
,
t
mpl
:=
range
tt
.
templates
{
for
_
,
t
tt
:=
range
tt
.
templates
{
t
.
Run
(
""
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
ttt
.
name
,
func
(
t
*
testing
.
T
)
{
tmpl
,
err
:=
Parse
(
t
mpl
)
tmpl
,
err
:=
Parse
(
t
tt
.
template
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment