prompt.go 3.69 KB
Newer Older
1
2
3
package server

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"log/slog"
	"slices"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/model/mllama"
	"github.com/ollama/ollama/template"
)

Michael Yang's avatar
Michael Yang committed
18
19
// tokenizeFunc converts a prompt string into its token representation,
// returning the token IDs or an error.
type tokenizeFunc func(context.Context, string) ([]int, error)

20
21
// errTooManyImages is returned when a message for a vision model that only
// supports one image per message (mllama) contains more than one image.
var errTooManyImages = errors.New("vision model only supports a single image per message")

Michael Yang's avatar
Michael Yang committed
22
23
24
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
Michael Yang's avatar
tools  
Michael Yang committed
25
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
Michael Yang's avatar
Michael Yang committed
26
	var system []api.Message
27
28
29

	isMllama := checkMllamaModelFamily(m)

30
31
32
33
34
35
36
37
38
39
	var imageNumTokens int
	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
	if isMllama {
		// Our mllama implementation packs all of the embeddings into a single token
		imageNumTokens = 1
	} else {
		// Clip images are represented as 768 tokens, each an embedding
		imageNumTokens = 768
	}

Michael Yang's avatar
Michael Yang committed
40
	n := len(msgs) - 1
Michael Yang's avatar
Michael Yang committed
41
	// in reverse, find all messages that fit into context window
42
43
44
45
46
47
48
49
50
51
	for i := n; i >= 0; i-- {
		if isMllama && len(msgs[i].Images) > 1 {
			return "", nil, errTooManyImages
		}

		// always include the last message
		if i == n {
			continue
		}

Michael Yang's avatar
Michael Yang committed
52
53
54
55
56
57
58
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}

Michael Yang's avatar
Michael Yang committed
59
		var b bytes.Buffer
Michael Yang's avatar
tools  
Michael Yang committed
60
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
Michael Yang's avatar
Michael Yang committed
61
			return "", nil, err
62
63
		}

Michael Yang's avatar
Michael Yang committed
64
		s, err := tokenize(ctx, b.String())
65
		if err != nil {
Michael Yang's avatar
Michael Yang committed
66
			return "", nil, err
67
68
		}

69
		ctxLen := len(s)
Michael Yang's avatar
Michael Yang committed
70
		if m.ProjectorPaths != nil {
Michael Yang's avatar
Michael Yang committed
71
			for _, m := range msgs[i:] {
72
				ctxLen += imageNumTokens * len(m.Images)
Michael Yang's avatar
Michael Yang committed
73
			}
74
75
		}

76
		if ctxLen > opts.NumCtx {
Michael Yang's avatar
Michael Yang committed
77
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
78
			break
Michael Yang's avatar
Michael Yang committed
79
80
		} else {
			n = i
81
		}
Michael Yang's avatar
Michael Yang committed
82
	}
83

84
85
	currMsgIdx := n

86
87
88
89
90
91
92
93
94
	for cnt, msg := range msgs[currMsgIdx:] {
		prefix := ""
		imgPrompt := ""
		prompt := msg.Content

		for _, i := range msg.Images {
			var imgData llm.ImageData

			if isMllama {
95
				data, opts, err := mllama.Preprocess(bytes.NewReader(i))
96
97
98
99
100
101
102
103
104
105
				if err != nil {
					return "", nil, err
				}

				buf := new(bytes.Buffer)
				err = binary.Write(buf, binary.LittleEndian, data)
				if err != nil {
					return "", nil, err
				}

106
107
108
109
110
				ar, ok := opts["aspectRatioIndex"].(int)
				if !ok {
					return "", nil, fmt.Errorf("missing aspect ratio for image")
				}

111
112
				imgData = llm.ImageData{
					ID:            len(images),
113
					Data:          buf.Bytes(),
114
					AspectRatioID: ar,
115
				}
116
117
118
				imgPrompt = "<|image|>"
			} else {
				imgData = llm.ImageData{
119
120
121
					ID:   len(images),
					Data: i,
				}
122
			}
123

124
125
126
127
128
			imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
			if !strings.Contains(prompt, "[img]") {
				prefix += imgTag
			} else {
				prompt = strings.Replace(prompt, "[img]", imgTag, 1)
129
			}
130
131

			images = append(images, imgData)
132
		}
133
		msgs[currMsgIdx+cnt].Content = prefix + imgPrompt + prompt
134
135
	}

Michael Yang's avatar
Michael Yang committed
136
	// truncate any messages that do not fit into the context window
Michael Yang's avatar
Michael Yang committed
137
	var b bytes.Buffer
138
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
Michael Yang's avatar
Michael Yang committed
139
		return "", nil, err
140
141
	}

142
143
144
145
146
147
148
	return b.String(), images, nil
}

// checkMllamaModelFamily reports whether m declares "mllama" among its model
// families. mllama models get special handling: single-token image embeddings
// and a one-image-per-message limit.
func checkMllamaModelFamily(m *Model) bool {
	// slices.Contains replaces the hand-rolled membership loop.
	return slices.Contains(m.Config.ModelFamilies, "mllama")
}