prompt.go 3.27 KB
Newer Older
1
2
3
package server

import (
Michael Yang's avatar
Michael Yang committed
4
5
	"bytes"
	"context"
6
7
	"errors"
	"fmt"
8
	"log/slog"
9
	"slices"
10
	"strings"
11

12
	"github.com/ollama/ollama/api"
Michael Yang's avatar
Michael Yang committed
13
	"github.com/ollama/ollama/llm"
Devon Rifkin's avatar
Devon Rifkin committed
14
	"github.com/ollama/ollama/model/renderers"
Michael Yang's avatar
Michael Yang committed
15
	"github.com/ollama/ollama/template"
16
17
)

Michael Yang's avatar
Michael Yang committed
18
19
20
21
22
// tokenizeFunc converts a rendered prompt string into its token IDs. chatPrompt only uses the number of tokens
// returned to decide whether a candidate prompt fits into the context window.
type tokenizeFunc func(context.Context, string) ([]int, error)

// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
Michael Yang's avatar
Michael Yang committed
23
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) {
Michael Yang's avatar
Michael Yang committed
24
	var system []api.Message
25

26
	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
27
28
	// Clip images are represented as 768 tokens, each an embedding
	imageNumTokens := 768
29

Michael Yang's avatar
Michael Yang committed
30
	n := len(msgs) - 1
Michael Yang's avatar
Michael Yang committed
31
	// in reverse, find all messages that fit into context window
32
33
34
35
36
37
	for i := n; i >= 0; i-- {
		// always include the last message
		if i == n {
			continue
		}

Michael Yang's avatar
Michael Yang committed
38
39
40
41
42
43
44
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
			}
		}

Devon Rifkin's avatar
Devon Rifkin committed
45
46
		p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
		if err != nil {
Michael Yang's avatar
Michael Yang committed
47
			return "", nil, err
48
49
		}

Devon Rifkin's avatar
Devon Rifkin committed
50
		s, err := tokenize(ctx, p)
51
		if err != nil {
Michael Yang's avatar
Michael Yang committed
52
			return "", nil, err
53
54
		}

55
		ctxLen := len(s)
Michael Yang's avatar
Michael Yang committed
56
		if m.ProjectorPaths != nil {
Michael Yang's avatar
Michael Yang committed
57
			for _, m := range msgs[i:] {
58
				ctxLen += imageNumTokens * len(m.Images)
Michael Yang's avatar
Michael Yang committed
59
			}
60
61
		}

62
		if ctxLen > opts.NumCtx {
Michael Yang's avatar
Michael Yang committed
63
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
64
			break
Michael Yang's avatar
Michael Yang committed
65
66
		} else {
			n = i
67
		}
Michael Yang's avatar
Michael Yang committed
68
	}
69

70
71
	currMsgIdx := n

72
	for cnt, msg := range msgs[currMsgIdx:] {
73
74
75
76
77
		if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
			return "", nil, errors.New("this model only supports one image while more than one image requested")
		}

		var prefix string
78
79
80
		prompt := msg.Content

		for _, i := range msg.Images {
81
82
83
			imgData := llm.ImageData{
				ID:   len(images),
				Data: i,
84
			}
85

86
87
88
89
90
			imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
			if !strings.Contains(prompt, "[img]") {
				prefix += imgTag
			} else {
				prompt = strings.Replace(prompt, "[img]", imgTag, 1)
91
			}
92
93

			images = append(images, imgData)
94
		}
95
		msgs[currMsgIdx+cnt].Content = prefix + prompt
96
97
	}

Michael Yang's avatar
Michael Yang committed
98
	// truncate any messages that do not fit into the context window
Devon Rifkin's avatar
Devon Rifkin committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
	p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
	if err != nil {
		return "", nil, err
	}

	return p, images, nil
}

func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
	if m.Config.Renderer != "" {
		rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
		if err != nil {
			return "", err
		}
		return rendered, nil
	}

Michael Yang's avatar
Michael Yang committed
116
	var b bytes.Buffer
117
	thinkVal := false
Michael Yang's avatar
Michael Yang committed
118
	thinkLevel := ""
119
	if think != nil {
120
121
		thinkVal = think.Bool()
		thinkLevel = think.String()
122
	}
Devon Rifkin's avatar
Devon Rifkin committed
123
124
	if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
		return "", err
125
	}
Devon Rifkin's avatar
Devon Rifkin committed
126
	return b.String(), nil
127
}