prompt.go 3.84 KB
Newer Older
1
2
3
package server

import (
Michael Yang's avatar
Michael Yang committed
4
5
	"bytes"
	"context"
6
7
8
	"encoding/binary"
	"errors"
	"fmt"
9
	"log/slog"
10
	"strings"
11

12
	"github.com/ollama/ollama/api"
Michael Yang's avatar
Michael Yang committed
13
	"github.com/ollama/ollama/llm"
14
	"github.com/ollama/ollama/model/models/mllama"
Michael Yang's avatar
Michael Yang committed
15
	"github.com/ollama/ollama/template"
16
17
)

Michael Yang's avatar
Michael Yang committed
18
19
// tokenizeFunc converts a string into a slice of ints, or reports an error.
// chatPrompt calls it on rendered template output and uses the length of the
// result as the prompt's token count.
type tokenizeFunc func(context.Context, string) ([]int, error)

20
21
// errTooManyImages is returned by chatPrompt when a message examined for an
// mllama-family model carries more than one image.
var errTooManyImages = errors.New("vision model only supports a single image per message")

Michael Yang's avatar
Michael Yang committed
22
23
24
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
Michael Yang's avatar
tools  
Michael Yang committed
25
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
Michael Yang's avatar
Michael Yang committed
26
	var system []api.Message
27
28
29

	isMllama := checkMllamaModelFamily(m)

30
31
32
33
34
35
36
37
38
39
	var imageNumTokens int
	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
	if isMllama {
		// Our mllama implementation packs all of the embeddings into a single token
		imageNumTokens = 1
	} else {
		// Clip images are represented as 768 tokens, each an embedding
		imageNumTokens = 768
	}

Michael Yang's avatar
Michael Yang committed
40
	n := len(msgs) - 1
Michael Yang's avatar
Michael Yang committed
41
	// in reverse, find all messages that fit into context window
42
	for i := n; i >= 0; i-- {
43
		if isMllama && len(msgs[i].Images) > 1 {
44
45
46
47
48
49
50
51
			return "", nil, errTooManyImages
		}

		// always include the last message
		if i == n {
			continue
		}

Michael Yang's avatar
Michael Yang committed
52
53
54
55
56
57
58
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}

Michael Yang's avatar
Michael Yang committed
59
		var b bytes.Buffer
Michael Yang's avatar
tools  
Michael Yang committed
60
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
Michael Yang's avatar
Michael Yang committed
61
			return "", nil, err
62
63
		}

Michael Yang's avatar
Michael Yang committed
64
		s, err := tokenize(ctx, b.String())
65
		if err != nil {
Michael Yang's avatar
Michael Yang committed
66
			return "", nil, err
67
68
		}

69
		ctxLen := len(s)
Michael Yang's avatar
Michael Yang committed
70
		if m.ProjectorPaths != nil {
Michael Yang's avatar
Michael Yang committed
71
			for _, m := range msgs[i:] {
72
				ctxLen += imageNumTokens * len(m.Images)
Michael Yang's avatar
Michael Yang committed
73
			}
74
75
		}

76
		if ctxLen > opts.NumCtx {
Michael Yang's avatar
Michael Yang committed
77
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
78
			break
Michael Yang's avatar
Michael Yang committed
79
80
		} else {
			n = i
81
		}
Michael Yang's avatar
Michael Yang committed
82
	}
83

84
85
	currMsgIdx := n

86
87
88
89
90
91
92
93
94
	for cnt, msg := range msgs[currMsgIdx:] {
		prefix := ""
		imgPrompt := ""
		prompt := msg.Content

		for _, i := range msg.Images {
			var imgData llm.ImageData

			if isMllama {
95
				if len(m.ProjectorPaths) == 0 {
Jesse Gross's avatar
Jesse Gross committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
					imgData = llm.ImageData{
						ID:   len(images),
						Data: i,
					}
				} else {
					data, opts, err := mllama.Preprocess(bytes.NewReader(i))
					if err != nil {
						return "", nil, err
					}

					buf := new(bytes.Buffer)
					err = binary.Write(buf, binary.LittleEndian, data)
					if err != nil {
						return "", nil, err
					}

					ar, ok := opts["aspectRatioIndex"].(int)
					if !ok {
						return "", nil, fmt.Errorf("missing aspect ratio for image")
					}

					imgData = llm.ImageData{
						ID:            len(images),
						Data:          buf.Bytes(),
						AspectRatioID: ar,
					}
122
				}
123
124
125
				imgPrompt = "<|image|>"
			} else {
				imgData = llm.ImageData{
126
127
128
					ID:   len(images),
					Data: i,
				}
129
			}
130

131
132
133
134
135
			imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
			if !strings.Contains(prompt, "[img]") {
				prefix += imgTag
			} else {
				prompt = strings.Replace(prompt, "[img]", imgTag, 1)
136
			}
137
138

			images = append(images, imgData)
139
		}
140
		msgs[currMsgIdx+cnt].Content = prefix + imgPrompt + prompt
141
142
	}

Michael Yang's avatar
Michael Yang committed
143
	// truncate any messages that do not fit into the context window
Michael Yang's avatar
Michael Yang committed
144
	var b bytes.Buffer
145
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
Michael Yang's avatar
Michael Yang committed
146
		return "", nil, err
147
148
	}

149
150
151
152
153
154
155
	return b.String(), images, nil
}

// checkMllamaModelFamily reports whether "mllama" is listed among the model
// families in m's config.
func checkMllamaModelFamily(m *Model) bool {
	found := false
	for _, family := range m.Config.ModelFamilies {
		if family != "mllama" {
			continue
		}
		found = true
		break
	}
	return found
}