Unverified commit 0e19476b, authored by Michael Yang and committed by GitHub
Browse files

prepend image tags (#2789)

Instead of appending image tags to the prompt, prepend them; this generally produces better results.
parent fa2f2b35
...@@ -121,13 +121,15 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str ...@@ -121,13 +121,15 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
p = prompt{} p = prompt{}
} }
p.Prompt = msg.Content var sb strings.Builder
for range msg.Images { for range msg.Images {
p.Prompt += fmt.Sprintf(" [img-%d]", imgId) fmt.Fprintf(&sb, "[img-%d] ", imgId)
p.images = append(p.images, imgId) p.images = append(p.images, imgId)
imgId += 1 imgId += 1
} }
sb.WriteString(msg.Content)
p.Prompt = sb.String()
case "assistant": case "assistant":
if p.Response != "" { if p.Response != "" {
prompts = append(prompts, p) prompts = append(prompts, p)
......
...@@ -155,7 +155,7 @@ func TestChatPrompt(t *testing.T) { ...@@ -155,7 +155,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}}, {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
}, },
window: 1024, window: 1024,
want: "You are a Wizard. Hello [img-0]", want: "You are a Wizard. [img-0] Hello",
}, },
{ {
name: "images truncated", name: "images truncated",
...@@ -165,7 +165,7 @@ func TestChatPrompt(t *testing.T) { ...@@ -165,7 +165,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}}, {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
}, },
window: 1024, window: 1024,
want: "You are a Wizard. Hello [img-1]", want: "You are a Wizard. [img-0] [img-1] Hello",
}, },
{ {
name: "empty list", name: "empty list",
...@@ -198,7 +198,7 @@ func TestChatPrompt(t *testing.T) { ...@@ -198,7 +198,7 @@ func TestChatPrompt(t *testing.T) {
} }
if got != tc.want { if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want) t.Errorf("got: %q, want: %q", got, tc.want)
} }
}) })
} }
......
...@@ -250,28 +250,29 @@ func GenerateHandler(c *gin.Context) { ...@@ -250,28 +250,29 @@ func GenerateHandler(c *gin.Context) {
slog.Debug("generate handler", "system", req.System) slog.Debug("generate handler", "system", req.System)
var sb strings.Builder var sb strings.Builder
if req.Context != nil { for i := range req.Images {
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context) fmt.Fprintf(&sb, "[img-%d] ", i)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
} }
sb.WriteString(prev) sb.WriteString(req.Prompt)
}
// write image tags p, err := Prompt(req.Template, req.System, sb.String(), "", true)
// TODO: limit the number of images to fit in the context similar to the chat endpoint if err != nil {
for i := range req.Images { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
req.Prompt += fmt.Sprintf(" [img-%d]", i) return
} }
p, err := Prompt(req.Template, req.System, req.Prompt, "", true) sb.Reset()
if req.Context != nil {
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sb.WriteString(prev)
}
sb.WriteString(p) sb.WriteString(p)
prompt = sb.String() prompt = sb.String()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment