ministral.go 3.34 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package parsers

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"

	"github.com/ollama/ollama/api"
)

// ministralParserState identifies which part of the model output the
// MinistralParser is currently consuming.
type ministralParserState int

// Parser states. The constants are explicitly typed so they carry the
// ministralParserState type instead of being untyped integers.
const (
	// ministralCollectingContent collects regular assistant content.
	ministralCollectingContent ministralParserState = iota
	// ministralCollectingThinkingContent collects thinking text up to [/THINK].
	ministralCollectingThinkingContent
	// ministralCollectingToolName collects the tool name that precedes [ARGS].
	ministralCollectingToolName
	// ministralCollectingToolArgs collects the arguments that follow [ARGS].
	ministralCollectingToolArgs
)

// MinistralParser splits Ministral model output into regular content,
// thinking text, and tool calls. It is prepared with Init and fed output
// chunks through Add.
type MinistralParser struct {
	// state is the part of the output currently being consumed.
	state              ministralParserState
	// buffer accumulates text passed to Add until it can be classified.
	buffer             strings.Builder
	// tools are the tools recorded by Init; tool names found in the output
	// are resolved against this list.
	tools              []api.Tool
	// hasThinkingSupport is reported by HasThinkingSupport and decides whether
	// setInitialState may start in the thinking state.
	// NOTE(review): not assigned in this file section — presumably set by
	// whoever constructs the parser; confirm.
	hasThinkingSupport bool
	// currentTool is the tool whose name preceded [ARGS]; its name is used
	// for the ToolCall built from the following arguments.
	currentTool        *api.Tool
}

// HasToolSupport reports whether this parser extracts tool calls from the
// model output. It unconditionally returns true.
func (p *MinistralParser) HasToolSupport() bool {
	const supportsTools = true
	return supportsTools
}

// HasThinkingSupport reports the parser's hasThinkingSupport flag, which
// decides whether output may start out as thinking text.
func (p *MinistralParser) HasThinkingSupport() bool {
	supported := p.hasThinkingSupport
	return supported
}

// setInitialState picks the state the parser starts a new response in.
//
// Without thinking support, the parser always starts by collecting content.
// With thinking support, it starts by collecting thinking text. The exception
// is when the last message is an assistant message that already has content
// (a prefill); then it collects regular content.
func (p *MinistralParser) setInitialState(lastMessage *api.Message) {
	switch {
	case !p.HasThinkingSupport():
		p.state = ministralCollectingContent
	case lastMessage != nil && lastMessage.Role == "assistant" && lastMessage.Content != "":
		p.state = ministralCollectingContent
	default:
		p.state = ministralCollectingThinkingContent
	}
}

// Init prepares the parser for a new response. It records the available
// tools, chooses the starting state from lastMessage, and returns the same
// tools slice it was given. thinkValue is currently not consulted.
func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	p.setInitialState(lastMessage)
	return p.tools
}

// toolByName returns a pointer into tools for the first entry whose
// Function.Name equals n. It returns an error if no tool has that name.
func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
	for idx := 0; idx < len(tools); idx++ {
		if candidate := &tools[idx]; candidate.Function.Name == n {
			return candidate, nil
		}
	}
	return nil, fmt.Errorf("tool '%s' not found", n)
}

// Control tags recognized in Ministral output, and the whitespace trimmed
// around them.
const (
	ministralTagToolCalls = "[TOOL_CALLS]"
	ministralTagThink     = "[THINK]"
	ministralTagEndThink  = "[/THINK]"
	ministralTagArgs      = "[ARGS]"
	ministralSpace        = " \t\r\n"
)

// Add feeds the next chunk of model output into the parser. It returns the
// content, thinking text, and tool calls that can be emitted so far.
//
// The output shape handled here is:
//
//	[THINK]thinking[/THINK]content[TOOL_CALLS]name[ARGS]{json args}...
//
// Buffered text is processed in a loop. A single chunk may therefore cross
// several state transitions (for example "[/THINK]answer[TOOL_CALLS]f[ARGS]{}"),
// and one call may return several tool calls. While streaming (done == false),
// a trailing fragment that could start a control tag (such as "[TOOL_") is
// held back until more input arrives, so tags split across chunks are still
// recognized. When done is true, remaining content or thinking text is
// flushed.
//
// Tool arguments are decoded as one complete JSON value, so nested objects
// and "}" inside strings are handled. An unknown tool name, malformed
// arguments, or an unknown parser state is reported through err, and the
// other results are then empty.
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	rest := p.buffer.String()
	// Whatever has not been consumed stays buffered for the next call.
	defer func() {
		p.buffer.Reset()
		p.buffer.WriteString(rest)
	}()

	var contentOut, thinkingOut strings.Builder
	for {
		switch p.state {
		case ministralCollectingContent:
			i, tag := ministralFirstTag(rest, ministralTagToolCalls, ministralTagThink)
			if i < 0 {
				n := len(rest)
				if !done {
					n -= ministralPartialTagLen(rest, ministralTagToolCalls, ministralTagThink)
				}
				contentOut.WriteString(rest[:n])
				rest = rest[n:]
				return contentOut.String(), thinkingOut.String(), calls, nil
			}
			before := rest[:i]
			if tag == ministralTagToolCalls {
				before = strings.TrimRight(before, ministralSpace)
				p.state = ministralCollectingToolName
			} else {
				p.state = ministralCollectingThinkingContent
			}
			// Text preceding a tag is real content and must not be dropped.
			contentOut.WriteString(before)
			rest = rest[i+len(tag):]

		case ministralCollectingThinkingContent:
			before, after, found := strings.Cut(rest, ministralTagEndThink)
			if !found {
				n := len(rest)
				if !done {
					n -= ministralPartialTagLen(rest, ministralTagEndThink)
				}
				thinkingOut.WriteString(rest[:n])
				rest = rest[n:]
				return contentOut.String(), thinkingOut.String(), calls, nil
			}
			thinkingOut.WriteString(strings.TrimRight(before, ministralSpace))
			rest = strings.TrimLeft(after, ministralSpace)
			p.state = ministralCollectingContent

		case ministralCollectingToolName:
			name, after, found := strings.Cut(rest, ministralTagArgs)
			if !found {
				// The tool name is still streaming in.
				return contentOut.String(), thinkingOut.String(), calls, nil
			}
			tool, lookupErr := toolByName(p.tools, strings.TrimSpace(name))
			if lookupErr != nil {
				return "", "", nil, lookupErr
			}
			p.currentTool = tool
			p.state = ministralCollectingToolArgs
			rest = after

		case ministralCollectingToolArgs:
			// Decode exactly one JSON value. This copes with nested braces and
			// braces inside strings, and it tells us where the value ends.
			dec := json.NewDecoder(strings.NewReader(rest))
			var args map[string]any
			if decErr := dec.Decode(&args); decErr != nil {
				if errors.Is(decErr, io.EOF) || errors.Is(decErr, io.ErrUnexpectedEOF) {
					// The arguments are still streaming in.
					return contentOut.String(), thinkingOut.String(), calls, nil
				}
				return "", "", nil, fmt.Errorf("parsing arguments for tool %q: %w", p.currentTool.Function.Name, decErr)
			}
			calls = append(calls, api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      p.currentTool.Function.Name,
					Arguments: api.ToolCallFunctionArguments(args),
				},
			})
			rest = rest[dec.InputOffset():]
			p.currentTool = nil
			p.state = ministralCollectingContent

		default:
			return "", "", nil, fmt.Errorf("ministral parser: unknown state %d", p.state)
		}
	}
}

// ministralFirstTag returns the index and text of whichever tag occurs
// earliest in s, or (-1, "") when none of them occurs.
func ministralFirstTag(s string, tags ...string) (int, string) {
	idx, found := -1, ""
	for _, tag := range tags {
		if i := strings.Index(s, tag); i >= 0 && (idx < 0 || i < idx) {
			idx, found = i, tag
		}
	}
	return idx, found
}

// ministralPartialTagLen returns the length of the longest suffix of s that
// is a proper prefix of one of tags. That suffix might still become a tag
// once more input arrives, so it must not be emitted yet.
func ministralPartialTagLen(s string, tags ...string) int {
	longest := 0
	for _, tag := range tags {
		n := len(tag) - 1
		if n > len(s) {
			n = len(s)
		}
		for ; n > longest; n-- {
			if strings.HasSuffix(s, tag[:n]) {
				longest = n
				break
			}
		}
	}
	return longest
}