prompt.go 3.61 KB
package server

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"log/slog"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/server/imageproc"
	"github.com/ollama/ollama/template"
)

type tokenizeFunc func(context.Context, string) ([]int, error)

var errTooManyImages = errors.New("vision model only supports a single image per message")

// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
	var system []api.Message

	isMllama := checkMllamaModelFamily(m)

	var imageNumTokens int
	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
	if isMllama {
		// Our mllama implementation packs all of the embeddings into a single token
		imageNumTokens = 1
	} else {
		// Clip images are represented as 768 tokens, each an embedding
		imageNumTokens = 768
	}

	n := len(msgs) - 1
	// in reverse, find all messages that fit into the context window
	for i := n; i >= 0; i-- {
		if isMllama && len(msgs[i].Images) > 1 {
			return "", nil, errTooManyImages
		}

		// always include the last message
		if i == n {
			continue
		}

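		// collect every system message that appears before msgs[i] so the
		// system prompt is always kept, even when older turns are dropped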
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}

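		// render the candidate window with the model's template and tokenize
		// the result to measure how much of the context it would consume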
		var b bytes.Buffer
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
			return "", nil, err
		}

		s, err := tokenize(ctx, b.String())
		if err != nil {
			return "", nil, err
		}

		ctxLen := len(s)
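		// for multimodal models, add the estimated per-image token cost on
		// top of the text tokens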
		if m.ProjectorPaths != nil {
			for _, m := range msgs[i:] {
				ctxLen += imageNumTokens * len(m.Images)
			}
		}

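		// stop once the rendered window no longer fits; otherwise extend the
		// window to start at msgs[i] and keep scanning backwards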
		if ctxLen > opts.NumCtx {
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
			break
		} else {
			n = i
		}
	}

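	// second pass: register each image on the messages that made the cut and
	// rewrite its reference in the content as an [img-N] tag, replacing an
	// explicit [img] placeholder when present or prefixing the tag otherwise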
	currMsgIdx := n

	for cnt, msg := range msgs[currMsgIdx:] {
		prefix := ""
		imgPrompt := ""
		prompt := msg.Content

		for _, i := range msg.Images {
			var imgData llm.ImageData

			if isMllama {
				data, aspectRatioID, err := imageproc.Preprocess(i)
				if err != nil {
					return "", nil, err
				}

				buf := new(bytes.Buffer)
				err = binary.Write(buf, binary.LittleEndian, data)
				if err != nil {
					return "", nil, err
				}

				imgData = llm.ImageData{
					ID:            len(images),
					Data:          buf.Bytes(),
					AspectRatioID: aspectRatioID,
				}
				imgPrompt = "<|image|>"
			} else {
				imgData = llm.ImageData{
					ID:   len(images),
					Data: i,
				}
				imgPrompt = " "
			}

			imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
			if !strings.Contains(prompt, "[img]") {
				prefix += imgTag
			} else {
				prompt = strings.Replace(prompt, "[img]", imgTag, 1)
			}

			images = append(images, imgData)
		}
		msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + imgPrompt + prompt)
	}

	// render the final prompt from the system messages and the messages that
	// fit into the context window (anything older was truncated above)
	var b bytes.Buffer
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
		return "", nil, err
	}

	return b.String(), images, nil
}

// checkMllamaModelFamily reports whether the model's configuration lists
// "mllama" among its model families.
func checkMllamaModelFamily(m *Model) bool {
	for _, arch := range m.Config.ModelFamilies {
		if arch == "mllama" {
			return true
		}
	}
	return false
}