package server

import (
	"bytes"
	"context"
	"log/slog"
	"slices"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/template"
)

// tokenizeFunc converts a rendered prompt into its model token IDs; chatPrompt
// uses the length of the result to measure context-window usage.
type tokenizeFunc func(context.Context, string) ([]int, error)

// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
	// pull out any system messages which should always be included in the prompt
Michael Yang's avatar
Michael Yang committed
21
22
23
24
25
	var system []api.Message
	msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
		if m.Role == "system" {
			system = append(system, m)
			return true
26
27
		}

Michael Yang's avatar
Michael Yang committed
28
29
		return false
	})
30

Michael Yang's avatar
Michael Yang committed
31
	if len(system) == 0 && m.System != "" {
Michael Yang's avatar
Michael Yang committed
32
		// add model system prompt since it wasn't provided
Michael Yang's avatar
Michael Yang committed
33
		system = append(system, api.Message{Role: "system", Content: m.System})
34
35
	}

Michael Yang's avatar
Michael Yang committed
36
	// always include the last message
Michael Yang's avatar
Michael Yang committed
37
	n := len(msgs) - 1
Michael Yang's avatar
Michael Yang committed
38
	// in reverse, find all messages that fit into context window
Michael Yang's avatar
Michael Yang committed
39
40
	for i := n - 1; i >= 0; i-- {
		var b bytes.Buffer
Michael Yang's avatar
Michael Yang committed
41
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
Michael Yang's avatar
Michael Yang committed
42
			return "", nil, err
43
44
		}

Michael Yang's avatar
Michael Yang committed
45
		s, err := tokenize(ctx, b.String())
46
		if err != nil {
Michael Yang's avatar
Michael Yang committed
47
			return "", nil, err
48
49
		}

Michael Yang's avatar
Michael Yang committed
50
		c := len(s)
Michael Yang's avatar
Michael Yang committed
51
		if m.ProjectorPaths != nil {
Michael Yang's avatar
Michael Yang committed
52
			for _, m := range msgs[i:] {
Michael Yang's avatar
Michael Yang committed
53
54
				// images are represented as 768 sized embeddings
				// TODO: get embedding length from project metadata
Michael Yang's avatar
Michael Yang committed
55
56
				c += 768 * len(m.Images)
			}
57
58
		}

Michael Yang's avatar
Michael Yang committed
59
		if c > opts.NumCtx {
Michael Yang's avatar
Michael Yang committed
60
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
61
			break
Michael Yang's avatar
Michael Yang committed
62
63
		} else {
			n = i
64
		}
Michael Yang's avatar
Michael Yang committed
65
	}
66

Michael Yang's avatar
Michael Yang committed
67
	// truncate any messages that do not fit into the context window
Michael Yang's avatar
Michael Yang committed
68
	var b bytes.Buffer
Michael Yang's avatar
Michael Yang committed
69
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
Michael Yang's avatar
Michael Yang committed
70
		return "", nil, err
71
72
	}

Michael Yang's avatar
Michael Yang committed
73
74
75
76
77
78
	for _, m := range msgs[n:] {
		for _, i := range m.Images {
			images = append(images, llm.ImageData{
				ID:   len(images),
				Data: i,
			})
79
80
81
		}
	}

Michael Yang's avatar
Michael Yang committed
82
	return b.String(), images, nil
83
}