You need to sign in or sign up before continuing.
Unverified Commit d9e60f63 authored by Patrick Devine's avatar Patrick Devine Committed by GitHub
Browse files

add image support to the chat api (#1490)

parent 4251b342
...@@ -57,8 +57,9 @@ type ChatRequest struct { ...@@ -57,8 +57,9 @@ type ChatRequest struct {
} }
type Message struct { type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"] Role string `json:"role"` // one of ["system", "user", "assistant"]
Content string `json:"content"` Content string `json:"content"`
Images []ImageData `json:"images, omitempty"`
} }
type ChatResponse struct { type ChatResponse struct {
......
...@@ -86,9 +86,10 @@ func (m *Model) Prompt(p PromptVars) (string, error) { ...@@ -86,9 +86,10 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
return prompt.String(), nil return prompt.String(), nil
} }
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) { func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
// build the prompt from the list of messages // build the prompt from the list of messages
var prompt strings.Builder var prompt strings.Builder
var currentImages []api.ImageData
currentVars := PromptVars{ currentVars := PromptVars{
First: true, First: true,
} }
...@@ -108,35 +109,36 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) { ...@@ -108,35 +109,36 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
case "system": case "system":
if currentVars.System != "" { if currentVars.System != "" {
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", nil, err
} }
} }
currentVars.System = msg.Content currentVars.System = msg.Content
case "user": case "user":
if currentVars.Prompt != "" { if currentVars.Prompt != "" {
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", nil, err
} }
} }
currentVars.Prompt = msg.Content currentVars.Prompt = msg.Content
currentImages = msg.Images
case "assistant": case "assistant":
currentVars.Response = msg.Content currentVars.Response = msg.Content
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", nil, err
} }
default: default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
} }
} }
// Append the last set of vars if they are non-empty // Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" { if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", nil, err
} }
} }
return prompt.String(), nil return prompt.String(), currentImages, nil
} }
type ManifestV2 struct { type ManifestV2 struct {
......
...@@ -994,7 +994,7 @@ func ChatHandler(c *gin.Context) { ...@@ -994,7 +994,7 @@ func ChatHandler(c *gin.Context) {
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
prompt, err := model.ChatPrompt(req.Messages) prompt, images, err := model.ChatPrompt(req.Messages)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
...@@ -1037,6 +1037,7 @@ func ChatHandler(c *gin.Context) { ...@@ -1037,6 +1037,7 @@ func ChatHandler(c *gin.Context) {
Format: req.Format, Format: req.Format,
CheckpointStart: checkpointStart, CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded, CheckpointLoaded: checkpointLoaded,
Images: images,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment