Commit d290e875 authored by Michael Yang
Browse files

add suffix support to generate endpoint

The new behavior is triggered by the presence of "suffix" in the generate
request, which is particularly useful for code completion tasks.
parent 987dbab0
...@@ -47,6 +47,9 @@ type GenerateRequest struct { ...@@ -47,6 +47,9 @@ type GenerateRequest struct {
// Prompt is the textual prompt to send to the model. // Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
// Suffix is the text that comes after the inserted text.
Suffix string `json:"suffix"`
// System overrides the model's default system message/prompt. // System overrides the model's default system message/prompt.
System string `json:"system"` System string `json:"system"`
......
...@@ -34,13 +34,19 @@ import ( ...@@ -34,13 +34,19 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
var errCapabilityCompletion = errors.New("completion") var (
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
)
type Capability string type Capability string
const ( const (
CapabilityCompletion = Capability("completion") CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools") CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
) )
type registryOptions struct { type registryOptions struct {
...@@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { ...@@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
} }
case CapabilityTools: case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") { if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errors.New("tools")) errs = append(errs, errCapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
} }
default: default:
slog.Error("unknown capability", "capability", cap) slog.Error("unknown capability", "capability", cap)
...@@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { ...@@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
} }
if err := errors.Join(errs...); err != nil { if err := errors.Join(errs...); err != nil {
return fmt.Errorf("does not support %w", errors.Join(errs...)) return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
} }
return nil return nil
......
...@@ -122,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -122,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
if req.Suffix != "" {
caps = append(caps, CapabilityInsert)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
...@@ -150,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -150,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt := req.Prompt prompt := req.Prompt
if !req.Raw { if !req.Raw {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
tmpl := m.Template tmpl := m.Template
if req.Template != "" { if req.Template != "" {
tmpl, err = template.Parse(req.Template) tmpl, err = template.Parse(req.Template)
...@@ -183,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -183,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
b.WriteString(s) b.WriteString(s)
} }
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil { var values template.Values
if req.Suffix != "" {
values.Prompt = prompt
values.Suffix = req.Suffix
} else {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
}
if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
...@@ -1394,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1394,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
func handleScheduleError(c *gin.Context, name string, err error) { func handleScheduleError(c *gin.Context, name string, err error) {
switch { switch {
case errors.Is(err, errRequired): case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"}) c.JSON(499, gin.H{"error": "request canceled"})
......
...@@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) { ...@@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) {
getCpuFn: gpu.GetCPUInfo, getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond, reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add 10ms delay to simulate loading
time.Sleep(10 * time.Millisecond)
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
llama: &mock, llama: &mock,
} }
...@@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) { ...@@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) {
go s.sched.Run(context.TODO()) go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test", Model: "test",
Modelfile: fmt.Sprintf(`FROM %s Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """ TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }} {{- if .System }}System: {{ .System }} {{ end }}
...@@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) { ...@@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) {
} }
}) })
t.Run("missing capabilities", func(t *testing.T) { t.Run("missing capabilities chat", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "bert", Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert", "general.architecture": "bert",
"bert.pooling_type": uint32(0), "bert.pooling_type": uint32(0),
...@@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) { ...@@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) {
} }
if actual.TotalDuration == 0 { if actual.TotalDuration == 0 {
t.Errorf("expected load duration > 0, got 0") t.Errorf("expected total duration > 0, got 0")
} }
} }
...@@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) { ...@@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) {
go s.sched.Run(context.TODO()) go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test", Model: "test",
Modelfile: fmt.Sprintf(`FROM %s Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """ TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }} {{- if .System }}System: {{ .System }} {{ end }}
...@@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) { ...@@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) {
} }
}) })
t.Run("missing capabilities", func(t *testing.T) { t.Run("missing capabilities generate", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "bert", Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert", "general.architecture": "bert",
"bert.pooling_type": uint32(0), "bert.pooling_type": uint32(0),
...@@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) { ...@@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) {
} }
}) })
t.Run("missing capabilities suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) { t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test", Model: "test",
...@@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) { ...@@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) {
} }
if actual.TotalDuration == 0 { if actual.TotalDuration == 0 {
t.Errorf("expected load duration > 0, got 0") t.Errorf("expected total duration > 0, got 0")
} }
} }
...@@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) { ...@@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) {
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!") checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
}) })
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-suffix",
Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("prompt without suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("raw", func(t *testing.T) { t.Run("raw", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system", Model: "test-system",
......
...@@ -151,6 +151,8 @@ func (t *Template) Vars() []string { ...@@ -151,6 +151,8 @@ func (t *Template) Vars() []string {
type Values struct { type Values struct {
Messages []api.Message Messages []api.Message
Tools []api.Tool Tools []api.Tool
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates // forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool forceLegacy bool
...@@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template { ...@@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
func (t *Template) Execute(w io.Writer, v Values) error { func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages) system, messages := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
"System": system, "System": system,
"Messages": messages, "Messages": messages,
......
...@@ -359,3 +359,38 @@ Answer: `, ...@@ -359,3 +359,38 @@ Answer: `,
}) })
} }
} }
// TestExecuteWithSuffix renders a template that branches on .Suffix and
// checks the output for two kinds of input: a plain message list (expected to
// render just the prompt text) and a prompt/suffix pair (expected to render
// the <PRE>/<SUF>/<MID> form).
func TestExecuteWithSuffix(t *testing.T) {
	// Template under test: uses the infill layout when a suffix is present,
	// otherwise emits the prompt on its own.
	const src = `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`

	tmpl, err := Parse(src)
	if err != nil {
		t.Fatal(err)
	}

	tests := []struct {
		name   string
		values Values
		expect string
	}{
		{
			name:   "message",
			values: Values{Messages: []api.Message{{Role: "user", Content: "hello"}}},
			expect: "hello",
		},
		{
			name:   "prompt suffix",
			values: Values{Prompt: "def add(", Suffix: "return x"},
			expect: "<PRE> def add( <SUF>return x <MID>",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			var out bytes.Buffer
			if err := tmpl.Execute(&out, tc.values); err != nil {
				t.Fatal(err)
			}

			// cmp.Diff(x, y) reports x as "-" and y as "+", hence got first.
			if diff := cmp.Diff(out.String(), tc.expect); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
			}
		})
	}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment