Commit 58e3fff3 authored by Michael Yang's avatar Michael Yang
Browse files

rename templates to template

parent 3f0b309a
package templates
package template
import (
"bytes"
......@@ -7,9 +7,14 @@ import (
"errors"
"io"
"math"
"slices"
"strings"
"sync"
"text/template"
"text/template/parse"
"github.com/agnivade/levenshtein"
"golang.org/x/exp/maps"
)
//go:embed index.json
......@@ -18,8 +23,8 @@ var indexBytes []byte
//go:embed *.gotmpl
var templatesFS embed.FS
var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
var templates []*Template
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
var templates []*named
if err := json.Unmarshal(indexBytes, &templates); err != nil {
return nil, err
}
......@@ -37,23 +42,23 @@ var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
return templates, nil
})
type Template struct {
type named struct {
Name string `json:"name"`
Template string `json:"template"`
Bytes []byte
Bytes []byte
}
func (t Template) Reader() io.Reader {
func (t named) Reader() io.Reader {
return bytes.NewReader(t.Bytes)
}
func NamedTemplate(s string) (*Template, error) {
func Named(s string) (*named, error) {
templates, err := templatesOnce()
if err != nil {
return nil, err
}
var template *Template
var template *named
score := math.MaxInt
for _, t := range templates {
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
......@@ -68,3 +73,86 @@ func NamedTemplate(s string) (*Template, error) {
return nil, errors.New("no matching template found")
}
type Template struct {
*template.Template
raw string
}
func (t *Template) String() string {
return t.raw
}
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
func Parse(s string) (*Template, error) {
t, err := template.New("").Option("missingkey=zero").Parse(s)
if err != nil {
return nil, err
}
return &Template{Template: t, raw: s}, nil
}
func (t *Template) Vars() []string {
var vars []string
for _, n := range t.Tree.Root.Nodes {
vars = append(vars, parseNode(n)...)
}
set := make(map[string]struct{})
for _, n := range vars {
set[strings.ToLower(n)] = struct{}{}
}
vars = maps.Keys(set)
slices.Sort(vars)
return vars
}
func parseNode(n parse.Node) []string {
switch n := n.(type) {
case *parse.ActionNode:
return parseNode(n.Pipe)
case *parse.IfNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.RangeNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.WithNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.PipeNode:
var names []string
for _, c := range n.Cmds {
for _, a := range c.Args {
names = append(names, parseNode(a)...)
}
}
return names
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, parseNode(n)...)
}
return names
case *parse.FieldNode:
return n.Ident
}
return nil
}
package templates
package template
import (
"bufio"
......@@ -7,13 +7,14 @@ import (
"io"
"os"
"path/filepath"
"slices"
"testing"
"text/template"
"github.com/ollama/ollama/llm"
)
func TestKVChatTemplate(t *testing.T) {
func TestNamed(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
if err != nil {
t.Fatal(err)
......@@ -31,7 +32,7 @@ func TestKVChatTemplate(t *testing.T) {
t.Run(k, func(t *testing.T) {
kv := llm.KV{"tokenizer.chat_template": v}
s := kv.ChatTemplate()
r, err := NamedTemplate(s)
r, err := Named(s)
if err != nil {
t.Fatal(err)
}
......@@ -57,3 +58,32 @@ func TestKVChatTemplate(t *testing.T) {
}
}
}
func TestParse(t *testing.T) {
cases := []struct {
template string
capabilities []string
}{
{"{{ .Prompt }}", []string{"prompt"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
tmpl, err := Parse(tt.template)
if err != nil {
t.Fatal(err)
}
vars := tmpl.Vars()
if !slices.Equal(tt.capabilities, vars) {
t.Errorf("expected %v, got %v", tt.capabilities, vars)
}
})
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment