template.go 3.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package tools

import (
	"bytes"
	"log/slog"
	"slices"
	"strings"
	"text/template"
	"text/template/parse"
)

// parseTag finds the tool calling tag in a Go template, often
// <tool_call>, [TOOL_CALL] or similar. It locates the first if-node
// whose condition references .ToolCalls, takes the first non-whitespace
// text node inside it, and returns that text up to (not including) the
// first "{", with surrounding whitespace trimmed.
//
// If the template, its tree or the tree's root is nil, if no such
// if-node or text node exists, or if the remaining text is empty,
// parseTag returns "{" to indicate that json objects should be
// attempted to be parsed as tool calls.
func parseTag(tmpl *template.Template) string {
	if tmpl == nil || tmpl.Tree == nil {
		slog.Debug("template or tree is nil")
		return "{"
	}

	// a tree that has not been through Parse has no root list;
	// fall back instead of dereferencing a nil pointer
	if tmpl.Tree.Root == nil {
		slog.Debug("template tree has no root")
		return "{"
	}

	tc := findToolCallNode(tmpl.Tree.Root.Nodes)
	if tc == nil || tc.List == nil {
		return "{"
	}

	tn := findTextNode(tc.List.Nodes)
	if tn == nil {
		return "{"
	}

	// normalize windows line endings before trimming
	tag := strings.ReplaceAll(string(tn.Text), "\r\n", "\n")

	// avoid parsing { onwards as this may be a tool call
	// however keep '{' as a prefix if there is no tag
	// so that all json objects will be attempted to
	// be parsed as tool calls
	tag, _, _ = strings.Cut(tag, "{")
	tag = strings.TrimSpace(tag)
	if tag == "" {
		return "{"
	}

	return tag
}

// findToolCallNode walks nodes depth-first, in order, and returns the
// first IfNode whose condition has a field argument containing the
// identifier "ToolCalls" (for example .ToolCalls or .Message.ToolCalls).
// It descends into the bodies and else-branches of if, range and with
// nodes as well as plain list nodes, and returns nil when nothing matches.
func findToolCallNode(nodes []parse.Node) *parse.IfNode {
	// mentionsToolCalls reports whether any direct argument of any
	// command in the if condition is a field naming ToolCalls
	mentionsToolCalls := func(ifNode *parse.IfNode) bool {
		for _, cmd := range ifNode.Pipe.Cmds {
			for _, arg := range cmd.Args {
				field, isField := arg.(*parse.FieldNode)
				if isField && slices.Contains(field.Ident, "ToolCalls") {
					return true
				}
			}
		}
		return false
	}

	// searchBranches looks through body first and then, when present,
	// through the else branch
	searchBranches := func(body, alt *parse.ListNode) *parse.IfNode {
		if found := findToolCallNode(body.Nodes); found != nil {
			return found
		}
		if alt == nil {
			return nil
		}
		return findToolCallNode(alt.Nodes)
	}

	for _, node := range nodes {
		var found *parse.IfNode

		switch typed := node.(type) {
		case *parse.IfNode:
			if mentionsToolCalls(typed) {
				return typed
			}
			// not a match itself, keep looking in nested nodes
			found = searchBranches(typed.List, typed.ElseList)
		case *parse.ListNode:
			found = findToolCallNode(typed.Nodes)
		case *parse.RangeNode:
			found = searchBranches(typed.List, typed.ElseList)
		case *parse.WithNode:
			found = searchBranches(typed.List, typed.ElseList)
		}

		if found != nil {
			return found
		}
	}

	return nil
}

// findTextNode returns the first text node in nodes that is not made up
// solely of whitespace, searching depth-first. Plain list nodes are
// searched and, if empty-handed, skipped. The search ends at the first
// if, range or with node (whose body and then else-branch are searched)
// and at the first action node, so that text following such a construct
// is never returned. It returns nil when no text is found.
func findTextNode(nodes []parse.Node) *parse.TextNode {
	// firstIn searches body and then, when present, the else branch
	firstIn := func(body, alt *parse.ListNode) *parse.TextNode {
		if found := findTextNode(body.Nodes); found != nil {
			return found
		}
		if alt == nil {
			return nil
		}
		return findTextNode(alt.Nodes)
	}

	for _, node := range nodes {
		switch typed := node.(type) {
		case *parse.TextNode:
			// whitespace-only text is ignored and the scan moves on
			if len(bytes.TrimSpace(typed.Text)) > 0 {
				return typed
			}
		case *parse.ListNode:
			if found := findTextNode(typed.Nodes); found != nil {
				return found
			}
		case *parse.IfNode:
			// control constructs end the scan whether or not text was found
			return firstIn(typed.List, typed.ElseList)
		case *parse.RangeNode:
			return firstIn(typed.List, typed.ElseList)
		case *parse.WithNode:
			return firstIn(typed.List, typed.ElseList)
		case *parse.ActionNode:
			return nil
		}
	}

	return nil
}