template.go 10.3 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
4
5
6
7
8
package template

import (
	"bytes"
	"embed"
	"encoding/json"
	"errors"
	"io"
Michael Yang's avatar
Michael Yang committed
9
	"maps"
Michael Yang's avatar
Michael Yang committed
10
11
12
13
14
15
	"math"
	"slices"
	"strings"
	"sync"
	"text/template"
	"text/template/parse"
Michael Yang's avatar
Michael Yang committed
16
	"time"
Michael Yang's avatar
Michael Yang committed
17
18

	"github.com/agnivade/levenshtein"
Michael Yang's avatar
lint  
Michael Yang committed
19
20

	"github.com/ollama/ollama/api"
Michael Yang's avatar
Michael Yang committed
21
22
23
24
25
26
)

// indexBytes holds the raw contents of index.json, embedded at build time.
// templatesOnce decodes it into the list of known templates.
//
//go:embed index.json
var indexBytes []byte

//go:embed *.gotmpl
27
//go:embed *.json
Michael Yang's avatar
Michael Yang committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
var templatesFS embed.FS

var templatesOnce = sync.OnceValues(func() ([]*named, error) {
	var templates []*named
	if err := json.Unmarshal(indexBytes, &templates); err != nil {
		return nil, err
	}

	for _, t := range templates {
		bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
		if err != nil {
			return nil, err
		}

		// normalize line endings
		t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
44
45
46
47
48
49
50
51
52

		params, err := templatesFS.ReadFile(t.Name + ".json")
		if err != nil {
			continue
		}

		if err := json.Unmarshal(params, &t.Parameters); err != nil {
			return nil, err
		}
Michael Yang's avatar
Michael Yang committed
53
54
55
56
57
58
59
60
61
	}

	return templates, nil
})

// named is one entry of the embedded template index together with the data
// loaded for it by templatesOnce.
type named struct {
	// Name is the file stem used to locate <Name>.gotmpl and <Name>.json.
	Name string `json:"name"`

	// Template is the text Named compares its argument against.
	Template string `json:"template"`

	// Bytes is the content of <Name>.gotmpl with CRLF normalized to LF.
	Bytes []byte

	// Parameters is decoded from <Name>.json; it is nil when that file
	// could not be read.
	Parameters *struct {
		// Stop is presumably the list of stop sequences for the template —
		// TODO confirm against the consumers of this field.
		Stop []string `json:"stop"`
	}
}

// Reader returns a new reader over the template's bytes.
func (t named) Reader() io.Reader {
	var r io.Reader = bytes.NewReader(t.Bytes)
	return r
}

// Named returns the known template whose text has the smallest Levenshtein
// distance to s. On a tie the earliest entry of the index wins. An error is
// returned when the index cannot be loaded or when the best distance is 100
// or more.
func Named(s string) (*named, error) {
	known, err := templatesOnce()
	if err != nil {
		return nil, err
	}

	var (
		closest *named
		best    = math.MaxInt
	)
	for _, candidate := range known {
		distance := levenshtein.ComputeDistance(s, candidate.Template)
		if distance < best {
			best, closest = distance, candidate
		}
	}

	if best >= 100 {
		return nil, errors.New("no matching template found")
	}

	return closest, nil
}

Michael Yang's avatar
Michael Yang committed
94
95
// DefaultTemplate is the result of parsing "{{ .Prompt }}". The error returned
// by Parse is discarded.
var DefaultTemplate, _ = Parse("{{ .Prompt }}")

Michael Yang's avatar
Michael Yang committed
96
97
98
99
100
// Template is a parsed text/template together with the source text it was
// parsed from.
type Template struct {
	*template.Template
	// raw is the source passed to Parse; String returns it and Contains
	// searches it.
	raw string
}

Michael Yang's avatar
Michael Yang committed
101
// response is a template node that can be added to templates that don't already have one.
// It is the parse-tree equivalent of the action {{ .Response }}.
var response = parse.ActionNode{
	NodeType: parse.NodeAction,
	Pipe: &parse.PipeNode{
		NodeType: parse.NodePipe,
		Cmds: []*parse.CommandNode{
			{
				NodeType: parse.NodeCommand,
				Args: []parse.Node{
					&parse.FieldNode{
						NodeType: parse.NodeField,
						Ident:    []string{"Response"},
					},
				},
			},
		},
	},
}

Michael Yang's avatar
tools  
Michael Yang committed
120
121
122
123
124
var funcs = template.FuncMap{
	"json": func(v any) string {
		b, _ := json.Marshal(v)
		return string(b)
	},
Michael Yang's avatar
Michael Yang committed
125
126
127
128
129
	"currentDate": func(args ...string) string {
		// Currently ignoring the format argument, but accepting it for future use
		// Default format is YYYY-MM-DD
		return time.Now().Format("2006-01-02")
	},
Devon Rifkin's avatar
Devon Rifkin committed
130
131
132
133
134
135
136
137
138
139
	"toTypeScriptType": func(v any) string {
		if param, ok := v.(api.ToolProperty); ok {
			return param.ToTypeScriptType()
		}
		// Handle pointer case
		if param, ok := v.(*api.ToolProperty); ok && param != nil {
			return param.ToTypeScriptType()
		}
		return "any"
	},
Michael Yang's avatar
tools  
Michael Yang committed
140
141
}

Michael Yang's avatar
Michael Yang committed
142
func Parse(s string) (*Template, error) {
Michael Yang's avatar
tools  
Michael Yang committed
143
	tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
Michael Yang's avatar
Michael Yang committed
144
145

	tmpl, err := tmpl.Parse(s)
Michael Yang's avatar
Michael Yang committed
146
147
148
149
	if err != nil {
		return nil, err
	}

Michael Yang's avatar
Michael Yang committed
150
151
152
153
154
155
156
157
158
159
160
	t := Template{Template: tmpl, raw: s}
	if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
		// touch up the template and append {{ .Response }}
		tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
	}

	return &t, nil
}

func (t *Template) String() string {
	return t.raw
Michael Yang's avatar
Michael Yang committed
161
162
163
164
}

func (t *Template) Vars() []string {
	var vars []string
Michael Yang's avatar
Michael Yang committed
165
166
	for _, tt := range t.Templates() {
		for _, n := range tt.Root.Nodes {
Michael Yang's avatar
tools  
Michael Yang committed
167
			vars = append(vars, Identifiers(n)...)
Michael Yang's avatar
Michael Yang committed
168
		}
Michael Yang's avatar
Michael Yang committed
169
170
171
172
173
174
175
	}

	set := make(map[string]struct{})
	for _, n := range vars {
		set[strings.ToLower(n)] = struct{}{}
	}

Michael Yang's avatar
Michael Yang committed
176
	return slices.Sorted(maps.Keys(set))
Michael Yang's avatar
Michael Yang committed
177
178
}

Michael Yang's avatar
Michael Yang committed
179
180
181
182
// Contains reports whether s occurs in the raw template source.
func (t *Template) Contains(s string) bool {
	found := strings.Contains(t.raw, s)
	return found
}

Michael Yang's avatar
Michael Yang committed
183
184
type Values struct {
	Messages []api.Message
185
186
187
	api.Tools
	Prompt string
	Suffix string
188
	Think  bool
Michael Yang's avatar
Michael Yang committed
189
190
	// ThinkLevel contains the thinking level if Think is true and a string value was provided
	ThinkLevel string
191
192
193
	// whether or not the user explicitly set the thinking flag (vs. it being
	// implicitly false). Templates can't see whether `Think` is nil
	IsThinkSet bool
194
195
196

	// forceLegacy is a flag used to test compatibility with legacy templates
	forceLegacy bool
Michael Yang's avatar
Michael Yang committed
197
198
}

Michael Yang's avatar
tools  
Michael Yang committed
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
// Subtree searches the template's root tree depth-first and returns a new
// template whose only root node is the first node for which fn returns true.
// The search descends into list nodes and into the bodies (then the else
// bodies) of if, with and range nodes. The returned template has the
// package's helper functions installed. Subtree returns nil when no node
// matches.
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
	var find func(parse.Node) parse.Node
	find = func(node parse.Node) parse.Node {
		if fn(node) {
			return node
		}

		switch v := node.(type) {
		case *parse.ListNode:
			for _, child := range v.Nodes {
				if hit := find(child); hit != nil {
					return hit
				}
			}
		case *parse.BranchNode:
			// Body first, then the optional else body.
			if v.List != nil {
				if hit := find(v.List); hit != nil {
					return hit
				}
			}
			if v.ElseList != nil {
				if hit := find(v.ElseList); hit != nil {
					return hit
				}
			}
		case *parse.IfNode:
			return find(&v.BranchNode)
		case *parse.WithNode:
			return find(&v.BranchNode)
		case *parse.RangeNode:
			return find(&v.BranchNode)
		}

		return nil
	}

	match := find(t.Tree.Root)
	if match == nil {
		return nil
	}

	sub := &template.Template{
		Tree: &parse.Tree{
			Root: &parse.ListNode{
				Nodes: []parse.Node{match},
			},
		},
	}

	return sub.Funcs(funcs)
}

Michael Yang's avatar
Michael Yang committed
245
func (t *Template) Execute(w io.Writer, v Values) error {
Michael Yang's avatar
Michael Yang committed
246
	system, messages := collate(v.Messages)
247
248
	if v.Prompt != "" && v.Suffix != "" {
		return t.Template.Execute(w, map[string]any{
249
250
251
252
			"Prompt":     v.Prompt,
			"Suffix":     v.Suffix,
			"Response":   "",
			"Think":      v.Think,
Michael Yang's avatar
Michael Yang committed
253
			"ThinkLevel": v.ThinkLevel,
254
			"IsThinkSet": v.IsThinkSet,
255
256
		})
	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
Michael Yang's avatar
Michael Yang committed
257
		return t.Template.Execute(w, map[string]any{
258
259
260
261
262
			"System":     system,
			"Messages":   messages,
			"Tools":      v.Tools,
			"Response":   "",
			"Think":      v.Think,
Michael Yang's avatar
Michael Yang committed
263
			"ThinkLevel": v.ThinkLevel,
264
			"IsThinkSet": v.IsThinkSet,
Michael Yang's avatar
Michael Yang committed
265
266
267
		})
	}

Michael Yang's avatar
Michael Yang committed
268
	system = ""
Michael Yang's avatar
Michael Yang committed
269
	var b bytes.Buffer
270
	var prompt, response string
Michael Yang's avatar
Michael Yang committed
271
	for _, m := range messages {
Michael Yang's avatar
tools  
Michael Yang committed
272
		execute := func() error {
Michael Yang's avatar
Michael Yang committed
273
			if err := t.Template.Execute(&b, map[string]any{
274
275
276
277
				"System":     system,
				"Prompt":     prompt,
				"Response":   response,
				"Think":      v.Think,
Michael Yang's avatar
Michael Yang committed
278
				"ThinkLevel": v.ThinkLevel,
279
				"IsThinkSet": v.IsThinkSet,
Michael Yang's avatar
Michael Yang committed
280
281
282
283
			}); err != nil {
				return err
			}

284
			system = ""
Michael Yang's avatar
Michael Yang committed
285
286
			prompt = ""
			response = ""
Michael Yang's avatar
Michael Yang committed
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
			return nil
		}

		switch m.Role {
		case "system":
			if prompt != "" || response != "" {
				if err := execute(); err != nil {
					return err
				}
			}
			system = m.Content
		case "user":
			if response != "" {
				if err := execute(); err != nil {
					return err
				}
			}
			prompt = m.Content
		case "assistant":
			response = m.Content
Michael Yang's avatar
Michael Yang committed
307
308
309
310
		}
	}

	var cut bool
311
	nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
Michael Yang's avatar
tools  
Michael Yang committed
312
313
		if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
			cut = true
314
			return false
Michael Yang's avatar
Michael Yang committed
315
316
317
318
319
		}

		return cut
	})

320
321
	tree := parse.Tree{Root: nodes.(*parse.ListNode)}
	if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
322
323
324
325
		"System":     system,
		"Prompt":     prompt,
		"Response":   response,
		"Think":      v.Think,
Michael Yang's avatar
Michael Yang committed
326
		"ThinkLevel": v.ThinkLevel,
327
		"IsThinkSet": v.IsThinkSet,
Michael Yang's avatar
Michael Yang committed
328
329
330
331
332
333
334
335
	}); err != nil {
		return err
	}

	_, err := io.Copy(w, &b)
	return err
}

Michael Yang's avatar
Michael Yang committed
336
// collate messages based on role. consecutive messages of the same role are merged
337
338
// into a single message (except for tool messages which preserve individual metadata).
// collate also collects and returns all system messages.
339
// collate mutates message content adding image tags ([img-%d]) as needed
340
// todo(parthsareen): revisit for contextual image support
341
342
343
func collate(msgs []api.Message) (string, []*api.Message) {
	var system []string
	var collated []*api.Message
Michael Yang's avatar
Michael Yang committed
344
	for i := range msgs {
345
346
		if msgs[i].Role == "system" {
			system = append(system, msgs[i].Content)
347
348
		}

349
350
351
		// merges consecutive messages of the same role into a single message (except for tool messages)
		if len(collated) > 0 && collated[len(collated)-1].Role == msgs[i].Role && msgs[i].Role != "tool" {
			collated[len(collated)-1].Content += "\n\n" + msgs[i].Content
Michael Yang's avatar
Michael Yang committed
352
		} else {
353
			collated = append(collated, &msgs[i])
Michael Yang's avatar
Michael Yang committed
354
355
356
		}
	}

357
	return strings.Join(system, "\n\n"), collated
Michael Yang's avatar
Michael Yang committed
358
359
}

Michael Yang's avatar
tools  
Michael Yang committed
360
361
// Identifiers walks the node tree returning any identifiers it finds along the way.
//
// Field names (.Foo.Bar yields "Foo", "Bar") and variable names ($x.Foo yields
// "$x", "Foo") are reported in the order they are visited. The walk covers
// list, action, template, if, range and with nodes, pipelines (including
// parenthesized ones) and field chains such as (index .Messages 0).Content.
// Variables declared by a pipeline are not reported. Duplicates are kept.
// Nodes that carry no identifiers yield nil.
func Identifiers(n parse.Node) []string {
	switch n := n.(type) {
	case *parse.ListNode:
		if n == nil {
			return nil
		}

		var names []string
		for _, n := range n.Nodes {
			names = append(names, Identifiers(n)...)
		}

		return names
	case *parse.TemplateNode:
		// {{ template "name" }} has no pipeline; the nil case is handled by
		// the PipeNode branch.
		return Identifiers(n.Pipe)
	case *parse.ActionNode:
		return Identifiers(n.Pipe)
	case *parse.BranchNode:
		names := Identifiers(n.Pipe)
		for _, n := range []*parse.ListNode{n.List, n.ElseList} {
			if n != nil {
				names = append(names, Identifiers(n)...)
			}
		}
		return names
	case *parse.IfNode:
		return Identifiers(&n.BranchNode)
	case *parse.RangeNode:
		return Identifiers(&n.BranchNode)
	case *parse.WithNode:
		return Identifiers(&n.BranchNode)
	case *parse.PipeNode:
		// A typed nil pipeline arrives here as a non-nil interface value.
		if n == nil {
			return nil
		}

		var names []string
		for _, c := range n.Cmds {
			for _, a := range c.Args {
				names = append(names, Identifiers(a)...)
			}
		}
		return names
	case *parse.ChainNode:
		// Copy first: the inner result may be a slice owned by the parse
		// tree, which must not be appended to in place.
		names := append([]string(nil), Identifiers(n.Node)...)
		return append(names, n.Field...)
	case *parse.FieldNode:
		return n.Ident
	case *parse.VariableNode:
		return n.Ident
	}

	return nil
}
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472

// deleteNode walks the node list and deletes nodes that match the predicate
// this is currently to remove the {{ .Response }} node from templates
//
// The tree rooted at n is modified in place and returned; the result is nil
// when n itself matches. fn is called on every visited node before its
// children. Deleting a node has these effects on its parent:
//   - a list simply loses the child;
//   - a command that loses all arguments removes its whole pipeline, and an
//     action whose pipeline is removed is removed as well;
//   - an if/with/range node whose branch is deleted is removed;
//   - a deleted branch body becomes empty and a deleted else body is dropped.
//
// The condition pipeline of if/with/range nodes is not visited.
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
	var walk func(n parse.Node) parse.Node
	walk = func(n parse.Node) parse.Node {
		if fn(n) {
			return nil
		}

		switch t := n.(type) {
		case *parse.ListNode:
			var nodes []parse.Node
			for _, c := range t.Nodes {
				if n := walk(c); n != nil {
					nodes = append(nodes, n)
				}
			}

			t.Nodes = nodes
			return t
		case *parse.IfNode:
			// The branch is pruned in place; nil means it was deleted.
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.WithNode:
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.RangeNode:
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.BranchNode:
			if t.List != nil {
				if list, ok := walk(t.List).(*parse.ListNode); ok {
					t.List = list
				} else {
					// The body must stay non-nil for execution, so a deleted
					// body is emptied instead of removed.
					t.List.Nodes = nil
				}
			}

			if t.ElseList != nil {
				// A deleted else body leaves ElseList nil, i.e. "no else".
				t.ElseList, _ = walk(t.ElseList).(*parse.ListNode)
			}
		case *parse.ActionNode:
			pipe, ok := walk(t.Pipe).(*parse.PipeNode)
			if !ok {
				return nil
			}

			t.Pipe = pipe
		case *parse.PipeNode:
			var commands []*parse.CommandNode
			for _, c := range t.Cmds {
				var args []parse.Node
				for _, a := range c.Args {
					if n := walk(a); n != nil {
						args = append(args, n)
					}
				}

				if len(args) == 0 {
					return nil
				}

				c.Args = args
				commands = append(commands, c)
			}

			if len(commands) == 0 {
				return nil
			}

			t.Cmds = commands
		}

		return n
	}

	return walk(n)
}