prompt.go 3.3 KB
Newer Older
1
2
3
package server

import (
Michael Yang's avatar
Michael Yang committed
4
5
	"bytes"
	"context"
6
7
	"errors"
	"fmt"
8
	"log/slog"
9
	"slices"
10
	"strings"
11

12
	"github.com/ollama/ollama/api"
Michael Yang's avatar
Michael Yang committed
13
	"github.com/ollama/ollama/llm"
Devon Rifkin's avatar
Devon Rifkin committed
14
	"github.com/ollama/ollama/model/renderers"
Michael Yang's avatar
Michael Yang committed
15
	"github.com/ollama/ollama/template"
16
17
)

Michael Yang's avatar
Michael Yang committed
18
19
20
21
22
// tokenizeFunc takes a context and a string and returns a slice of ints or an error. chatPrompt calls it on each
// rendered candidate prompt and uses the length of the returned slice as that prompt's token count.
type tokenizeFunc func(context.Context, string) ([]int, error)

// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
23
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
Michael Yang's avatar
Michael Yang committed
24
	var system []api.Message
25

26
	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
27
28
	// Clip images are represented as 768 tokens, each an embedding
	imageNumTokens := 768
29

Michael Yang's avatar
Michael Yang committed
30
	n := len(msgs) - 1
Michael Yang's avatar
Michael Yang committed
31
	// in reverse, find all messages that fit into context window
32
33
34
35
36
37
	for i := n; i >= 0; i-- {
		// always include the last message
		if i == n {
			continue
		}

Michael Yang's avatar
Michael Yang committed
38
39
40
41
42
43
44
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}

Devon Rifkin's avatar
Devon Rifkin committed
45
46
		p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
		if err != nil {
Michael Yang's avatar
Michael Yang committed
47
			return "", nil, err
48
49
		}

Devon Rifkin's avatar
Devon Rifkin committed
50
		s, err := tokenize(ctx, p)
51
		if err != nil {
Michael Yang's avatar
Michael Yang committed
52
			return "", nil, err
53
54
		}

55
		ctxLen := len(s)
Michael Yang's avatar
Michael Yang committed
56
		if m.ProjectorPaths != nil {
Michael Yang's avatar
Michael Yang committed
57
			for _, m := range msgs[i:] {
58
				ctxLen += imageNumTokens * len(m.Images)
Michael Yang's avatar
Michael Yang committed
59
			}
60
61
		}

62
		if truncate && ctxLen > opts.NumCtx {
Michael Yang's avatar
Michael Yang committed
63
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
64
			break
Michael Yang's avatar
Michael Yang committed
65
66
		} else {
			n = i
67
		}
Michael Yang's avatar
Michael Yang committed
68
	}
69

70
71
	currMsgIdx := n

72
	for cnt, msg := range msgs[currMsgIdx:] {
73
74
75
76
77
		if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
			return "", nil, errors.New("this model only supports one image while more than one image requested")
		}

		var prefix string
78
79
80
		prompt := msg.Content

		for _, i := range msg.Images {
81
82
83
			imgData := llm.ImageData{
				ID:   len(images),
				Data: i,
84
			}
85

86
87
88
89
90
			imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
			if !strings.Contains(prompt, "[img]") {
				prefix += imgTag
			} else {
				prompt = strings.Replace(prompt, "[img]", imgTag, 1)
91
			}
92
93

			images = append(images, imgData)
94
		}
95
		msgs[currMsgIdx+cnt].Content = prefix + prompt
96
97
	}

Michael Yang's avatar
Michael Yang committed
98
	// truncate any messages that do not fit into the context window
Devon Rifkin's avatar
Devon Rifkin committed
99
100
101
102
103
104
105
106
107
108
	p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
	if err != nil {
		return "", nil, err
	}

	return p, images, nil
}

func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
	if m.Config.Renderer != "" {
109
		rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
Devon Rifkin's avatar
Devon Rifkin committed
110
111
112
113
114
115
		if err != nil {
			return "", err
		}
		return rendered, nil
	}

Michael Yang's avatar
Michael Yang committed
116
	var b bytes.Buffer
117
	thinkVal := false
Michael Yang's avatar
Michael Yang committed
118
	thinkLevel := ""
119
	if think != nil {
120
121
		thinkVal = think.Bool()
		thinkLevel = think.String()
122
	}
Devon Rifkin's avatar
Devon Rifkin committed
123
124
	if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
		return "", err
125
	}
Devon Rifkin's avatar
Devon Rifkin committed
126
	return b.String(), nil
127
}