prompt.go 4.05 KB
Newer Older
1
2
3
package server

import (
Michael Yang's avatar
Michael Yang committed
4
5
	"bytes"
	"context"
6
7
8
	"encoding/binary"
	"errors"
	"fmt"
9
	"log/slog"
10
	"strings"
11

12
	"github.com/ollama/ollama/api"
Michael Yang's avatar
Michael Yang committed
13
	"github.com/ollama/ollama/llm"
14
	"github.com/ollama/ollama/model/models/mllama"
Michael Yang's avatar
Michael Yang committed
15
	"github.com/ollama/ollama/template"
16
17
)

Michael Yang's avatar
Michael Yang committed
18
19
type (
	// tokenizeFunc turns a rendered prompt string into a slice of ints, one per
	// token; chatPrompt only uses the length of that slice to measure how much
	// of the context window a candidate prompt consumes.
	tokenizeFunc func(context.Context, string) ([]int, error)
)

var (
	// errTooManyImages is returned by chatPrompt when an mllama or gemma3 model
	// is given a message that carries more than one image.
	errTooManyImages = errors.New("vision model only supports a single image per message")
)

Michael Yang's avatar
Michael Yang committed
22
23
24
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
Michael Yang's avatar
tools  
Michael Yang committed
25
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
Michael Yang's avatar
Michael Yang committed
26
	var system []api.Message
27
28

	isMllama := checkMllamaModelFamily(m)
29
	isGemma3 := checkGemma3ModelFamily(m)
30

31
32
33
34
35
36
37
38
39
40
	var imageNumTokens int
	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
	if isMllama {
		// Our mllama implementation packs all of the embeddings into a single token
		imageNumTokens = 1
	} else {
		// Clip images are represented as 768 tokens, each an embedding
		imageNumTokens = 768
	}

Michael Yang's avatar
Michael Yang committed
41
	n := len(msgs) - 1
Michael Yang's avatar
Michael Yang committed
42
	// in reverse, find all messages that fit into context window
43
	for i := n; i >= 0; i-- {
44
		if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
45
46
47
48
49
50
51
52
			return "", nil, errTooManyImages
		}

		// always include the last message
		if i == n {
			continue
		}

Michael Yang's avatar
Michael Yang committed
53
54
55
56
57
58
59
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}

Michael Yang's avatar
Michael Yang committed
60
		var b bytes.Buffer
Michael Yang's avatar
tools  
Michael Yang committed
61
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
Michael Yang's avatar
Michael Yang committed
62
			return "", nil, err
63
64
		}

Michael Yang's avatar
Michael Yang committed
65
		s, err := tokenize(ctx, b.String())
66
		if err != nil {
Michael Yang's avatar
Michael Yang committed
67
			return "", nil, err
68
69
		}

70
		ctxLen := len(s)
Michael Yang's avatar
Michael Yang committed
71
		if m.ProjectorPaths != nil {
Michael Yang's avatar
Michael Yang committed
72
			for _, m := range msgs[i:] {
73
				ctxLen += imageNumTokens * len(m.Images)
Michael Yang's avatar
Michael Yang committed
74
			}
75
76
		}

77
		if ctxLen > opts.NumCtx {
Michael Yang's avatar
Michael Yang committed
78
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
79
			break
Michael Yang's avatar
Michael Yang committed
80
81
		} else {
			n = i
82
		}
Michael Yang's avatar
Michael Yang committed
83
	}
84

85
86
	currMsgIdx := n

87
88
89
90
91
92
93
94
95
	for cnt, msg := range msgs[currMsgIdx:] {
		prefix := ""
		imgPrompt := ""
		prompt := msg.Content

		for _, i := range msg.Images {
			var imgData llm.ImageData

			if isMllama {
96
				if len(m.ProjectorPaths) == 0 {
Jesse Gross's avatar
Jesse Gross committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
					imgData = llm.ImageData{
						ID:   len(images),
						Data: i,
					}
				} else {
					data, opts, err := mllama.Preprocess(bytes.NewReader(i))
					if err != nil {
						return "", nil, err
					}

					buf := new(bytes.Buffer)
					err = binary.Write(buf, binary.LittleEndian, data)
					if err != nil {
						return "", nil, err
					}

					ar, ok := opts["aspectRatioIndex"].(int)
					if !ok {
						return "", nil, fmt.Errorf("missing aspect ratio for image")
					}

					imgData = llm.ImageData{
						ID:            len(images),
						Data:          buf.Bytes(),
						AspectRatioID: ar,
					}
123
				}
124
125
126
				imgPrompt = "<|image|>"
			} else {
				imgData = llm.ImageData{
127
128
129
					ID:   len(images),
					Data: i,
				}
130
			}
131

132
133
134
135
136
			imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
			if !strings.Contains(prompt, "[img]") {
				prefix += imgTag
			} else {
				prompt = strings.Replace(prompt, "[img]", imgTag, 1)
137
			}
138
139

			images = append(images, imgData)
140
		}
141
		msgs[currMsgIdx+cnt].Content = prefix + imgPrompt + prompt
142
143
	}

Michael Yang's avatar
Michael Yang committed
144
	// truncate any messages that do not fit into the context window
Michael Yang's avatar
Michael Yang committed
145
	var b bytes.Buffer
146
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
Michael Yang's avatar
Michael Yang committed
147
		return "", nil, err
148
149
	}

150
151
152
153
154
155
156
	return b.String(), images, nil
}

// checkMllamaModelFamily reports whether "mllama" is one of the families listed
// in the model's configuration (m.Config.ModelFamilies).
func checkMllamaModelFamily(m *Model) bool {
	families := m.Config.ModelFamilies
	for idx := range families {
		if families[idx] == "mllama" {
			return true
		}
	}
	return false
}
161
162
163
164
165
166
167
168
169

// checkGemma3ModelFamily reports whether "gemma3" is one of the families listed
// in the model's configuration (m.Config.ModelFamilies).
func checkGemma3ModelFamily(m *Model) bool {
	found := false
	for _, family := range m.Config.ModelFamilies {
		if family == "gemma3" {
			found = true
			break
		}
	}
	return found
}