prompt.go 2 KB
Newer Older
mashun1's avatar
v1  
mashun1 committed
1
2
3
package server

import (
xuxzh1's avatar
init  
xuxzh1 committed
4
5
	"bytes"
	"context"
mashun1's avatar
v1  
mashun1 committed
6
7
8
	"log/slog"

	"github.com/ollama/ollama/api"
xuxzh1's avatar
init  
xuxzh1 committed
9
10
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/template"
mashun1's avatar
v1  
mashun1 committed
11
12
)

xuxzh1's avatar
init  
xuxzh1 committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// tokenizeFunc turns a string into a slice of ints (presumably token IDs —
// confirm against the model's tokenizer). chatPrompt only uses the length of
// the result to measure how much of the context window a rendered prompt
// occupies; a non-nil error aborts prompt construction.
type tokenizeFunc func(ctx context.Context, text string) ([]int, error)

// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
	var system []api.Message
	// always include the last message
	n := len(msgs) - 1
	// in reverse, find all messages that fit into context window
	for i := n - 1; i >= 0; i-- {
		system = make([]api.Message, 0)
		for j := range i {
			if msgs[j].Role == "system" {
				system = append(system, msgs[j])
mashun1's avatar
v1  
mashun1 committed
28
29
30
			}
		}

xuxzh1's avatar
init  
xuxzh1 committed
31
32
33
		var b bytes.Buffer
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
			return "", nil, err
mashun1's avatar
v1  
mashun1 committed
34
35
		}

xuxzh1's avatar
init  
xuxzh1 committed
36
		s, err := tokenize(ctx, b.String())
mashun1's avatar
v1  
mashun1 committed
37
		if err != nil {
xuxzh1's avatar
init  
xuxzh1 committed
38
			return "", nil, err
mashun1's avatar
v1  
mashun1 committed
39
40
		}

xuxzh1's avatar
init  
xuxzh1 committed
41
42
43
44
45
46
47
		c := len(s)
		if m.ProjectorPaths != nil {
			for _, m := range msgs[i:] {
				// images are represented as 768 sized embeddings
				// TODO: get embedding length from project metadata
				c += 768 * len(m.Images)
			}
mashun1's avatar
v1  
mashun1 committed
48
49
		}

xuxzh1's avatar
init  
xuxzh1 committed
50
51
		if c > opts.NumCtx {
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
mashun1's avatar
v1  
mashun1 committed
52
			break
xuxzh1's avatar
init  
xuxzh1 committed
53
54
		} else {
			n = i
mashun1's avatar
v1  
mashun1 committed
55
		}
xuxzh1's avatar
init  
xuxzh1 committed
56
	}
mashun1's avatar
v1  
mashun1 committed
57

xuxzh1's avatar
init  
xuxzh1 committed
58
59
60
61
	// truncate any messages that do not fit into the context window
	var b bytes.Buffer
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
		return "", nil, err
mashun1's avatar
v1  
mashun1 committed
62
63
	}

xuxzh1's avatar
init  
xuxzh1 committed
64
65
66
67
68
69
	for _, m := range msgs[n:] {
		for _, i := range m.Images {
			images = append(images, llm.ImageData{
				ID:   len(images),
				Data: i,
			})
mashun1's avatar
v1  
mashun1 committed
70
71
72
		}
	}

xuxzh1's avatar
init  
xuxzh1 committed
73
	return b.String(), images, nil
mashun1's avatar
v1  
mashun1 committed
74
}