Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
cd0853f2
Unverified
Commit
cd0853f2
authored
Jul 16, 2024
by
Michael Yang
Committed by
GitHub
Jul 16, 2024
Browse files
Merge pull request #5207 from ollama/mxyng/suffix
add insert support to generate endpoint
parents
97c20ede
d290e875
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
155 additions
and
27 deletions
+155
-27
api/types.go
api/types.go
+3
-0
server/images.go
server/images.go
+14
-3
server/routes.go
server/routes.go
+25
-15
server/routes_generate_test.go
server/routes_generate_test.go
+69
-8
template/template.go
template/template.go
+9
-1
template/template_test.go
template/template_test.go
+35
-0
No files found.
api/types.go
View file @
cd0853f2
...
@@ -47,6 +47,9 @@ type GenerateRequest struct {
...
@@ -47,6 +47,9 @@ type GenerateRequest struct {
// Prompt is the textual prompt to send to the model.
// Prompt is the textual prompt to send to the model.
Prompt
string
`json:"prompt"`
Prompt
string
`json:"prompt"`
// Suffix is the text that comes after the inserted text.
Suffix
string
`json:"suffix"`
// System overrides the model's default system message/prompt.
// System overrides the model's default system message/prompt.
System
string
`json:"system"`
System
string
`json:"system"`
...
...
server/images.go
View file @
cd0853f2
...
@@ -34,13 +34,19 @@ import (
...
@@ -34,13 +34,19 @@ import (
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/version"
)
)
var
errCapabilityCompletion
=
errors
.
New
(
"completion"
)
var
(
errCapabilities
=
errors
.
New
(
"does not support"
)
errCapabilityCompletion
=
errors
.
New
(
"completion"
)
errCapabilityTools
=
errors
.
New
(
"tools"
)
errCapabilityInsert
=
errors
.
New
(
"insert"
)
)
type
Capability
string
type
Capability
string
const
(
const
(
CapabilityCompletion
=
Capability
(
"completion"
)
CapabilityCompletion
=
Capability
(
"completion"
)
CapabilityTools
=
Capability
(
"tools"
)
CapabilityTools
=
Capability
(
"tools"
)
CapabilityInsert
=
Capability
(
"insert"
)
)
)
type
registryOptions
struct
{
type
registryOptions
struct
{
...
@@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
...
@@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
}
}
case
CapabilityTools
:
case
CapabilityTools
:
if
!
slices
.
Contains
(
m
.
Template
.
Vars
(),
"tools"
)
{
if
!
slices
.
Contains
(
m
.
Template
.
Vars
(),
"tools"
)
{
errs
=
append
(
errs
,
errors
.
New
(
"tools"
))
errs
=
append
(
errs
,
errCapabilityTools
)
}
case
CapabilityInsert
:
vars
:=
m
.
Template
.
Vars
()
if
!
slices
.
Contains
(
vars
,
"suffix"
)
{
errs
=
append
(
errs
,
errCapabilityInsert
)
}
}
default
:
default
:
slog
.
Error
(
"unknown capability"
,
"capability"
,
cap
)
slog
.
Error
(
"unknown capability"
,
"capability"
,
cap
)
...
@@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
...
@@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
}
}
if
err
:=
errors
.
Join
(
errs
...
);
err
!=
nil
{
if
err
:=
errors
.
Join
(
errs
...
);
err
!=
nil
{
return
fmt
.
Errorf
(
"
does not support %w"
,
errors
.
Join
(
errs
...
))
return
fmt
.
Errorf
(
"
%w %w"
,
errCapabilities
,
errors
.
Join
(
errs
...
))
}
}
return
nil
return
nil
...
...
server/routes.go
View file @
cd0853f2
...
@@ -122,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
...
@@ -122,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}
caps
:=
[]
Capability
{
CapabilityCompletion
}
caps
:=
[]
Capability
{
CapabilityCompletion
}
if
req
.
Suffix
!=
""
{
caps
=
append
(
caps
,
CapabilityInsert
)
}
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
caps
,
req
.
Options
,
req
.
KeepAlive
)
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
caps
,
req
.
Options
,
req
.
KeepAlive
)
if
errors
.
Is
(
err
,
errCapabilityCompletion
)
{
if
errors
.
Is
(
err
,
errCapabilityCompletion
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support generate"
,
req
.
Model
)})
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support generate"
,
req
.
Model
)})
...
@@ -150,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
...
@@ -150,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt
:=
req
.
Prompt
prompt
:=
req
.
Prompt
if
!
req
.
Raw
{
if
!
req
.
Raw
{
var
msgs
[]
api
.
Message
if
req
.
System
!=
""
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"system"
,
Content
:
req
.
System
})
}
else
if
m
.
System
!=
""
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"system"
,
Content
:
m
.
System
})
}
for
_
,
i
:=
range
images
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
fmt
.
Sprintf
(
"[img-%d]"
,
i
.
ID
)})
}
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
req
.
Prompt
})
tmpl
:=
m
.
Template
tmpl
:=
m
.
Template
if
req
.
Template
!=
""
{
if
req
.
Template
!=
""
{
tmpl
,
err
=
template
.
Parse
(
req
.
Template
)
tmpl
,
err
=
template
.
Parse
(
req
.
Template
)
...
@@ -183,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
...
@@ -183,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
b
.
WriteString
(
s
)
b
.
WriteString
(
s
)
}
}
if
err
:=
tmpl
.
Execute
(
&
b
,
template
.
Values
{
Messages
:
msgs
});
err
!=
nil
{
var
values
template
.
Values
if
req
.
Suffix
!=
""
{
values
.
Prompt
=
prompt
values
.
Suffix
=
req
.
Suffix
}
else
{
var
msgs
[]
api
.
Message
if
req
.
System
!=
""
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"system"
,
Content
:
req
.
System
})
}
else
if
m
.
System
!=
""
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"system"
,
Content
:
m
.
System
})
}
for
_
,
i
:=
range
images
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
fmt
.
Sprintf
(
"[img-%d]"
,
i
.
ID
)})
}
values
.
Messages
=
append
(
msgs
,
api
.
Message
{
Role
:
"user"
,
Content
:
req
.
Prompt
})
}
if
err
:=
tmpl
.
Execute
(
&
b
,
values
);
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
return
}
}
...
@@ -1394,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
...
@@ -1394,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
func
handleScheduleError
(
c
*
gin
.
Context
,
name
string
,
err
error
)
{
func
handleScheduleError
(
c
*
gin
.
Context
,
name
string
,
err
error
)
{
switch
{
switch
{
case
errors
.
Is
(
err
,
errRequired
)
:
case
errors
.
Is
(
err
,
errCapabilities
),
errors
.
Is
(
err
,
errRequired
)
:
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
case
errors
.
Is
(
err
,
context
.
Canceled
)
:
case
errors
.
Is
(
err
,
context
.
Canceled
)
:
c
.
JSON
(
499
,
gin
.
H
{
"error"
:
"request canceled"
})
c
.
JSON
(
499
,
gin
.
H
{
"error"
:
"request canceled"
})
...
...
server/routes_generate_test.go
View file @
cd0853f2
...
@@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) {
...
@@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) {
getCpuFn
:
gpu
.
GetCPUInfo
,
getCpuFn
:
gpu
.
GetCPUInfo
,
reschedDelay
:
250
*
time
.
Millisecond
,
reschedDelay
:
250
*
time
.
Millisecond
,
loadFn
:
func
(
req
*
LlmRequest
,
ggml
*
llm
.
GGML
,
gpus
gpu
.
GpuInfoList
,
numParallel
int
)
{
loadFn
:
func
(
req
*
LlmRequest
,
ggml
*
llm
.
GGML
,
gpus
gpu
.
GpuInfoList
,
numParallel
int
)
{
// add 10ms delay to simulate loading
time
.
Sleep
(
10
*
time
.
Millisecond
)
req
.
successCh
<-
&
runnerRef
{
req
.
successCh
<-
&
runnerRef
{
llama
:
&
mock
,
llama
:
&
mock
,
}
}
...
@@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) {
...
@@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) {
go
s
.
sched
.
Run
(
context
.
TODO
())
go
s
.
sched
.
Run
(
context
.
TODO
())
w
:=
createRequest
(
t
,
s
.
CreateModelHandler
,
api
.
CreateRequest
{
w
:=
createRequest
(
t
,
s
.
CreateModelHandler
,
api
.
CreateRequest
{
Name
:
"test"
,
Model
:
"test"
,
Modelfile
:
fmt
.
Sprintf
(
`FROM %s
Modelfile
:
fmt
.
Sprintf
(
`FROM %s
TEMPLATE """
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .System }}System: {{ .System }} {{ end }}
...
@@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) {
...
@@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) {
}
}
})
})
t
.
Run
(
"missing capabilities"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"missing capabilities
chat
"
,
func
(
t
*
testing
.
T
)
{
w
:=
createRequest
(
t
,
s
.
CreateModelHandler
,
api
.
CreateRequest
{
w
:=
createRequest
(
t
,
s
.
CreateModelHandler
,
api
.
CreateRequest
{
Name
:
"bert"
,
Model
:
"bert"
,
Modelfile
:
fmt
.
Sprintf
(
"FROM %s"
,
createBinFile
(
t
,
llm
.
KV
{
Modelfile
:
fmt
.
Sprintf
(
"FROM %s"
,
createBinFile
(
t
,
llm
.
KV
{
"general.architecture"
:
"bert"
,
"general.architecture"
:
"bert"
,
"bert.pooling_type"
:
uint32
(
0
),
"bert.pooling_type"
:
uint32
(
0
),
...
@@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) {
...
@@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) {
}
}
if
actual
.
TotalDuration
==
0
{
if
actual
.
TotalDuration
==
0
{
t
.
Errorf
(
"expected
load
duration > 0, got 0"
)
t
.
Errorf
(
"expected
total
duration > 0, got 0"
)
}
}
}
}
...
@@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) {
...
@@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) {
go
s
.
sched
.
Run
(
context
.
TODO
())
go
s
.
sched
.
Run
(
context
.
TODO
())
w
:=
createRequest
(
t
,
s
.
CreateModelHandler
,
api
.
CreateRequest
{
w
:=
createRequest
(
t
,
s
.
CreateModelHandler
,
api
.
CreateRequest
{
Name
:
"test"
,
Model
:
"test"
,
Modelfile
:
fmt
.
Sprintf
(
`FROM %s
Modelfile
:
fmt
.
Sprintf
(
`FROM %s
TEMPLATE """
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .System }}System: {{ .System }} {{ end }}
...
@@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) {
...
@@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) {
}
}
})
})
t
.
Run
(
"missing capabilities"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"missing capabilities
generate
"
,
func
(
t
*
testing
.
T
)
{
w
:=
createRequest
(
t
,
s
.
CreateModelHandler
,
api
.
CreateRequest
{
w
:=
createRequest
(
t
,
s
.
CreateModelHandler
,
api
.
CreateRequest
{
Name
:
"bert"
,
Model
:
"bert"
,
Modelfile
:
fmt
.
Sprintf
(
"FROM %s"
,
createBinFile
(
t
,
llm
.
KV
{
Modelfile
:
fmt
.
Sprintf
(
"FROM %s"
,
createBinFile
(
t
,
llm
.
KV
{
"general.architecture"
:
"bert"
,
"general.architecture"
:
"bert"
,
"bert.pooling_type"
:
uint32
(
0
),
"bert.pooling_type"
:
uint32
(
0
),
...
@@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) {
...
@@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) {
}
}
})
})
t
.
Run
(
"missing capabilities suffix"
,
func
(
t
*
testing
.
T
)
{
w
:=
createRequest
(
t
,
s
.
GenerateHandler
,
api
.
GenerateRequest
{
Model
:
"test"
,
Prompt
:
"def add("
,
Suffix
:
" return c"
,
})
if
w
.
Code
!=
http
.
StatusBadRequest
{
t
.
Errorf
(
"expected status 400, got %d"
,
w
.
Code
)
}
if
diff
:=
cmp
.
Diff
(
w
.
Body
.
String
(),
`{"error":"test does not support insert"}`
);
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
t
.
Run
(
"load model"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"load model"
,
func
(
t
*
testing
.
T
)
{
w
:=
createRequest
(
t
,
s
.
GenerateHandler
,
api
.
GenerateRequest
{
w
:=
createRequest
(
t
,
s
.
GenerateHandler
,
api
.
GenerateRequest
{
Model
:
"test"
,
Model
:
"test"
,
...
@@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) {
...
@@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) {
}
}
if
actual
.
TotalDuration
==
0
{
if
actual
.
TotalDuration
==
0
{
t
.
Errorf
(
"expected
load
duration > 0, got 0"
)
t
.
Errorf
(
"expected
total
duration > 0, got 0"
)
}
}
}
}
...
@@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) {
...
@@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) {
checkGenerateResponse
(
t
,
w
.
Body
,
"test-system"
,
"Abra kadabra!"
)
checkGenerateResponse
(
t
,
w
.
Body
,
"test-system"
,
"Abra kadabra!"
)
})
})
w
=
createRequest
(
t
,
s
.
CreateModelHandler
,
api
.
CreateRequest
{
Model
:
"test-suffix"
,
Modelfile
:
`FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`
,
})
if
w
.
Code
!=
http
.
StatusOK
{
t
.
Fatalf
(
"expected status 200, got %d"
,
w
.
Code
)
}
t
.
Run
(
"prompt with suffix"
,
func
(
t
*
testing
.
T
)
{
w
:=
createRequest
(
t
,
s
.
GenerateHandler
,
api
.
GenerateRequest
{
Model
:
"test-suffix"
,
Prompt
:
"def add("
,
Suffix
:
" return c"
,
})
if
w
.
Code
!=
http
.
StatusOK
{
t
.
Errorf
(
"expected status 200, got %d"
,
w
.
Code
)
}
if
diff
:=
cmp
.
Diff
(
mock
.
CompletionRequest
.
Prompt
,
"<PRE> def add( <SUF> return c <MID>"
);
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
t
.
Run
(
"prompt without suffix"
,
func
(
t
*
testing
.
T
)
{
w
:=
createRequest
(
t
,
s
.
GenerateHandler
,
api
.
GenerateRequest
{
Model
:
"test-suffix"
,
Prompt
:
"def add("
,
})
if
w
.
Code
!=
http
.
StatusOK
{
t
.
Errorf
(
"expected status 200, got %d"
,
w
.
Code
)
}
if
diff
:=
cmp
.
Diff
(
mock
.
CompletionRequest
.
Prompt
,
"def add("
);
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
t
.
Run
(
"raw"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"raw"
,
func
(
t
*
testing
.
T
)
{
w
:=
createRequest
(
t
,
s
.
GenerateHandler
,
api
.
GenerateRequest
{
w
:=
createRequest
(
t
,
s
.
GenerateHandler
,
api
.
GenerateRequest
{
Model
:
"test-system"
,
Model
:
"test-system"
,
...
...
template/template.go
View file @
cd0853f2
...
@@ -151,6 +151,8 @@ func (t *Template) Vars() []string {
...
@@ -151,6 +151,8 @@ func (t *Template) Vars() []string {
type
Values
struct
{
type
Values
struct
{
Messages
[]
api
.
Message
Messages
[]
api
.
Message
Tools
[]
api
.
Tool
Tools
[]
api
.
Tool
Prompt
string
Suffix
string
// forceLegacy is a flag used to test compatibility with legacy templates
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy
bool
forceLegacy
bool
...
@@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
...
@@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
system
,
messages
:=
collate
(
v
.
Messages
)
system
,
messages
:=
collate
(
v
.
Messages
)
if
!
v
.
forceLegacy
&&
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
if
v
.
Prompt
!=
""
&&
v
.
Suffix
!=
""
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
"Prompt"
:
v
.
Prompt
,
"Suffix"
:
v
.
Suffix
,
"Response"
:
""
,
})
}
else
if
!
v
.
forceLegacy
&&
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
"System"
:
system
,
"System"
:
system
,
"Messages"
:
messages
,
"Messages"
:
messages
,
...
...
template/template_test.go
View file @
cd0853f2
...
@@ -359,3 +359,38 @@ Answer: `,
...
@@ -359,3 +359,38 @@ Answer: `,
})
})
}
}
}
}
func
TestExecuteWithSuffix
(
t
*
testing
.
T
)
{
tmpl
,
err
:=
Parse
(
`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
cases
:=
[]
struct
{
name
string
values
Values
expect
string
}{
{
"message"
,
Values
{
Messages
:
[]
api
.
Message
{{
Role
:
"user"
,
Content
:
"hello"
}}},
"hello"
,
},
{
"prompt suffix"
,
Values
{
Prompt
:
"def add("
,
Suffix
:
"return x"
},
"<PRE> def add( <SUF>return x <MID>"
,
},
}
for
_
,
tt
:=
range
cases
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
var
b
bytes
.
Buffer
if
err
:=
tmpl
.
Execute
(
&
b
,
tt
.
values
);
err
!=
nil
{
t
.
Fatal
(
err
)
}
if
diff
:=
cmp
.
Diff
(
b
.
String
(),
tt
.
expect
);
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment