ministral.go 3.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package parsers

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
)

// ministralParserState enumerates the states of MinistralParser's streaming
// state machine.
type ministralParserState int

// The constants are explicitly typed as ministralParserState; with a bare
// `= iota` they would be untyped ints and the named type would carry no
// meaning.
const (
	// ministralCollectingContent: plain assistant text is being emitted.
	ministralCollectingContent ministralParserState = iota
	// ministralCollectingThinkingContent: text is reasoning, up to [/THINK].
	ministralCollectingThinkingContent
	// ministralCollectingToolName: after [TOOL_CALLS], waiting for [ARGS].
	ministralCollectingToolName
	// ministralCollectingToolArgs: after [ARGS], collecting the JSON arguments.
	ministralCollectingToolArgs
)

// MinistralParser is a streaming parser that splits model output into
// content, thinking text, and tool calls, driven by the [THINK], [/THINK],
// [TOOL_CALLS] and [ARGS] markers handled in Add.
//
// Fields:
//   - state: the current state of the parsing state machine.
//   - buffer: received text that has not been fully consumed yet.
//   - tools: the tools passed to Init; tool-call names are looked up here.
//   - hasThinkingSupport: returned by HasThinkingSupport and used to pick the
//     initial state. It is not assigned in this file — presumably set by the
//     constructor; verify against the caller.
//   - currentTool: the tool whose arguments are currently being collected.
type MinistralParser struct {
	state              ministralParserState
	buffer             strings.Builder
	tools              []api.Tool
	hasThinkingSupport bool
	currentTool        *api.Tool
}

// HasToolSupport reports whether the parser can extract tool calls from the
// model output. It unconditionally returns true.
func (p *MinistralParser) HasToolSupport() bool {
	const toolsSupported = true
	return toolsSupported
}

// HasThinkingSupport reports whether this parser instance expects thinking
// output, as recorded in its hasThinkingSupport field.
func (p *MinistralParser) HasThinkingSupport() bool {
	supported := p.hasThinkingSupport
	return supported
}

// setInitialState chooses the state the parser starts in for a new response.
//
// Without thinking support the parser always starts in content mode. With
// thinking support it starts in thinking mode, unless lastMessage is an
// assistant message that already has content (a prefill). In that case it
// starts in content mode.
func (p *MinistralParser) setInitialState(lastMessage *api.Message) {
	prefilledContent := lastMessage != nil &&
		lastMessage.Role == "assistant" &&
		lastMessage.Content != ""

	switch {
	case !p.HasThinkingSupport(), prefilledContent:
		p.state = ministralCollectingContent
	default:
		p.state = ministralCollectingThinkingContent
	}
}

// Init prepares the parser for a new response. It picks the starting state
// from lastMessage, remembers tools for later name lookup, and returns tools
// unchanged. thinkValue is not consulted.
func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.setInitialState(lastMessage)
	p.tools = tools
	return p.tools
}

// toolByName returns a pointer to the first element of tools whose function
// name equals n, or an error if no tool matches. The pointer refers into the
// tools slice itself, not to a copy.
func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
	for idx := 0; idx < len(tools); idx++ {
		if candidate := &tools[idx]; candidate.Function.Name == n {
			return candidate, nil
		}
	}
	return nil, fmt.Errorf("tool '%s' not found", n)
}

func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)

	switch p.state {
	case ministralCollectingContent:
		if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
			before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
			if before != "" {
				return before, "", calls, nil
			}
			p.state = ministralCollectingToolName
		} else if strings.Contains(p.buffer.String(), "[THINK]") {
			p.state = ministralCollectingThinkingContent
			return "", "", calls, nil
		} else {
			p.buffer.Reset()
			return s, "", calls, nil
		}
	case ministralCollectingThinkingContent:
		if strings.Contains(p.buffer.String(), "[/THINK]") {
			thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
			p.state = ministralCollectingContent
			if after != "" {
				p.buffer.Reset()
				return after, thinkingContent, calls, nil
			}
			return "", thinkingContent, calls, nil
		} else {
			p.buffer.Reset()
			return "", s, calls, nil
		}
	case ministralCollectingToolName:
		if strings.Contains(p.buffer.String(), "[ARGS]") {
			name, _ := splitAtTag(&p.buffer, "[ARGS]", false)

			t, err := toolByName(p.tools, name)
			if err != nil {
				return "", "", calls, err
			}
			p.currentTool = t
			p.state = ministralCollectingToolArgs
			return "", "", calls, nil
		}
		return "", "", calls, nil
	case ministralCollectingToolArgs:
		if strings.Contains(p.buffer.String(), "}") {
			before, _ := splitAtTag(&p.buffer, "}", false)
			before += "}"

115
116
			var args api.ToolCallFunctionArguments
			if err := json.Unmarshal([]byte(before), &args); err != nil {
117
118
119
120
121
122
123
124
125
				// todo - throw a better error
				return "", "", calls, err
			}

			p.state = ministralCollectingContent

			call := api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      p.currentTool.Function.Name,
126
					Arguments: args,
127
128
129
130
131
132
133
134
135
136
				},
			}
			calls = append(calls, call)
			return "", "", calls, nil
		}
		return "", "", calls, nil
	}

	return p.buffer.String(), thinking, calls, nil
}