Commit c0a00f68 authored by Michael Yang's avatar Michael Yang
Browse files

refactor modelfile parser

parent f0c454ab
...@@ -6,8 +6,9 @@ import ( ...@@ -6,8 +6,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"slices" "slices"
"strconv"
"strings"
) )
type Command struct { type Command struct {
...@@ -15,118 +16,219 @@ type Command struct { ...@@ -15,118 +16,219 @@ type Command struct {
Args string Args string
} }
func (c *Command) Reset() { type state int
c.Name = ""
c.Args = ""
}
func Parse(reader io.Reader) ([]Command, error) { const (
var commands []Command stateNil state = iota
var command, modelCommand Command stateName
stateValue
stateParameter
stateMessage
stateComment
)
scanner := bufio.NewScanner(reader) var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
scanner.Split(scanModelfile) func Parse(r io.Reader) (cmds []Command, err error) {
for scanner.Scan() { var cmd Command
line := scanner.Bytes() var curr state
var b bytes.Buffer
var role string
br := bufio.NewReader(r)
for {
r, _, err := br.ReadRune()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return nil, err
}
fields := bytes.SplitN(line, []byte(" "), 2) next, r, err := parseRuneForState(r, curr)
if len(fields) == 0 || len(fields[0]) == 0 { if errors.Is(err, io.ErrUnexpectedEOF) {
continue return nil, fmt.Errorf("%w: %s", err, b.String())
} else if err != nil {
return nil, err
} }
switch string(bytes.ToUpper(fields[0])) { if next != curr {
case "FROM": switch curr {
command.Name = "model" case stateName, stateParameter:
command.Args = string(bytes.TrimSpace(fields[1])) switch s := strings.ToLower(b.String()); s {
// copy command for validation case "from":
modelCommand = command cmd.Name = "model"
case "ADAPTER": case "parameter":
command.Name = string(bytes.ToLower(fields[0])) next = stateParameter
command.Args = string(bytes.TrimSpace(fields[1])) case "message":
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT": next = stateMessage
command.Name = string(bytes.ToLower(fields[0])) fallthrough
command.Args = string(fields[1]) default:
case "PARAMETER": cmd.Name = s
fields = bytes.SplitN(fields[1], []byte(" "), 2) }
if len(fields) < 2 { case stateMessage:
return nil, fmt.Errorf("missing value for %s", fields) if !slices.Contains([]string{"system", "user", "assistant"}, b.String()) {
return nil, errInvalidRole
}
role = b.String()
case stateComment, stateNil:
// pass
case stateValue:
s := b.String()
s, ok := unquote(b.String())
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
continue
}
if role != "" {
s = role + ": " + s
role = ""
}
cmd.Args = s
cmds = append(cmds, cmd)
} }
command.Name = string(fields[0]) b.Reset()
command.Args = string(bytes.TrimSpace(fields[1])) curr = next
case "EMBED": }
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE": if strconv.IsPrint(r) {
command.Name = string(bytes.ToLower(fields[0])) if _, err := b.WriteRune(r); err != nil {
fields = bytes.SplitN(fields[1], []byte(" "), 2) return nil, err
if len(fields) < 2 {
return nil, fmt.Errorf("should be in the format <role> <message>")
}
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
}
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default:
if !bytes.HasPrefix(fields[0], []byte("#")) {
// log a warning for unknown commands
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
} }
continue }
}
// flush the buffer
switch curr {
case stateComment, stateNil:
// pass; nothing to flush
case stateValue:
if _, ok := unquote(b.String()); !ok {
return nil, io.ErrUnexpectedEOF
} }
commands = append(commands, command) cmd.Args = b.String()
command.Reset() cmds = append(cmds, cmd)
default:
return nil, io.ErrUnexpectedEOF
} }
if modelCommand.Args == "" { for _, cmd := range cmds {
return nil, errors.New("no FROM line for the model was specified") if cmd.Name == "model" {
return cmds, nil
}
} }
return commands, scanner.Err() return nil, errors.New("no FROM line")
} }
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) { func parseRuneForState(r rune, cs state) (state, rune, error) {
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF) switch cs {
if err != nil { case stateNil:
return 0, nil, err switch {
case r == '#':
return stateComment, 0, nil
case isSpace(r), isNewline(r):
return stateNil, 0, nil
default:
return stateName, r, nil
}
case stateName:
switch {
case isAlpha(r):
return stateName, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, errors.New("invalid")
}
case stateValue:
switch {
case isNewline(r):
return stateNil, r, nil
case isSpace(r):
return stateNil, r, nil
default:
return stateValue, r, nil
}
case stateParameter:
switch {
case isAlpha(r), isNumber(r), r == '_':
return stateParameter, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateMessage:
switch {
case isAlpha(r):
return stateMessage, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateComment:
switch {
case isNewline(r):
return stateNil, 0, nil
default:
return stateComment, 0, nil
}
default:
return stateNil, 0, errors.New("")
} }
}
if advance > 0 && token != nil { func unquote(s string) (string, bool) {
return advance, token, nil if len(s) == 0 {
return "", false
} }
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF) // TODO: single quotes
if err != nil { if len(s) >= 3 && s[:3] == `"""` {
return 0, nil, err if len(s) >= 6 && s[len(s)-3:] == `"""` {
return s[3 : len(s)-3], true
}
return "", false
} }
if advance > 0 && token != nil { if len(s) >= 1 && s[0] == '"' {
return advance, token, nil if len(s) >= 2 && s[len(s)-1] == '"' {
return s[1 : len(s)-1], true
}
return "", false
} }
return bufio.ScanLines(data, atEOF) return s, true
} }
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) { func isAlpha(r rune) bool {
newline := bytes.IndexByte(data, '\n') return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
}
if start := bytes.Index(data, openBytes); start >= 0 && start < newline { func isNumber(r rune) bool {
end := bytes.Index(data[start+len(openBytes):], closeBytes) return r >= '0' && r <= '9'
if end < 0 { }
if atEOF {
return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
} else {
return 0, nil, nil
}
}
n := start + len(openBytes) + end + len(closeBytes) func isSpace(r rune) bool {
return r == ' ' || r == '\t'
}
newData := data[:start] func isNewline(r rune) bool {
newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...) return r == '\r' || r == '\n'
return n, newData, nil }
}
return 0, nil, nil func isValidRole(role string) bool {
return role == "system" || role == "user" || role == "assistant"
} }
package parser package parser
import ( import (
"bytes"
"fmt"
"io"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Parser(t *testing.T) { func TestParser(t *testing.T) {
input := ` input := `
FROM model1 FROM model1
...@@ -35,7 +38,7 @@ TEMPLATE template1 ...@@ -35,7 +38,7 @@ TEMPLATE template1
assert.Equal(t, expectedCommands, commands) assert.Equal(t, expectedCommands, commands)
} }
func Test_Parser_NoFromLine(t *testing.T) { func TestParserNoFromLine(t *testing.T) {
input := ` input := `
PARAMETER param1 value1 PARAMETER param1 value1
...@@ -48,7 +51,7 @@ PARAMETER param2 value2 ...@@ -48,7 +51,7 @@ PARAMETER param2 value2
assert.ErrorContains(t, err, "no FROM line") assert.ErrorContains(t, err, "no FROM line")
} }
func Test_Parser_MissingValue(t *testing.T) { func TestParserParametersMissingValue(t *testing.T) {
input := ` input := `
FROM foo FROM foo
...@@ -58,41 +61,292 @@ PARAMETER param1 ...@@ -58,41 +61,292 @@ PARAMETER param1
reader := strings.NewReader(input) reader := strings.NewReader(input)
_, err := Parse(reader) _, err := Parse(reader)
assert.ErrorContains(t, err, "missing value for [param1]") assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
} }
func Test_Parser_Messages(t *testing.T) { func TestParserMessages(t *testing.T) {
var cases = []struct {
input := ` input string
expected []Command
err error
}{
{
`
FROM foo
MESSAGE system You are a Parser. Always Parse things.
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
},
nil,
},
{
`
FROM foo FROM foo
MESSAGE system You are a Parser. Always Parse things. MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there! MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things! MESSAGE assistant Hello, I want to parse all the things!
` `,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
{Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
},
nil,
},
{
`
FROM foo
MESSAGE system """
You are a multiline Parser. Always Parse things.
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
},
nil,
},
{
`
FROM foo
MESSAGE badguy I'm a bad guy!
`,
nil,
errInvalidRole,
},
{
`
FROM foo
MESSAGE system
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
MESSAGE system`,
nil,
io.ErrUnexpectedEOF,
},
}
reader := strings.NewReader(input) for _, c := range cases {
commands, err := Parse(reader) t.Run("", func(t *testing.T) {
assert.Nil(t, err) commands, err := Parse(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
})
}
}
expectedCommands := []Command{ func TestParserQuoted(t *testing.T) {
{Name: "model", Args: "foo"}, var cases = []struct {
{Name: "message", Args: "system: You are a Parser. Always Parse things."}, multiline string
{Name: "message", Args: "user: Hey there!"}, expected []Command
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, err error
}{
{
`
FROM foo
TEMPLATE """
This is a
multiline template.
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "\nThis is a\nmultiline template.\n"},
},
nil,
},
{
`
FROM foo
TEMPLATE """
This is a
multiline template."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "\nThis is a\nmultiline template."},
},
nil,
},
{
`
FROM foo
TEMPLATE """This is a
multiline template."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "This is a\nmultiline template."},
},
nil,
},
{
`
FROM foo
TEMPLATE """This is a multiline template."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "This is a multiline template."},
},
nil,
},
{
`
FROM foo
TEMPLATE """This is a multiline template.""
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
TEMPLATE "
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
TEMPLATE """
This is a multiline template with "quotes".
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "\nThis is a multiline template with \"quotes\".\n"},
},
nil,
},
{
`
FROM foo
TEMPLATE """"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: ""},
},
nil,
},
{
`
FROM foo
TEMPLATE ""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: ""},
},
nil,
},
{
`
FROM foo
TEMPLATE "'"
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "'"},
},
nil,
},
} }
assert.Equal(t, expectedCommands, commands) for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.multiline))
assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
})
}
} }
func Test_Parser_Messages_BadRole(t *testing.T) { func TestParserParameters(t *testing.T) {
var cases = []string{
"numa true",
"num_ctx 1",
"num_batch 1",
"num_gqa 1",
"num_gpu 1",
"main_gpu 1",
"low_vram true",
"f16_kv true",
"logits_all true",
"vocab_only true",
"use_mmap true",
"use_mlock true",
"num_thread 1",
"num_keep 1",
"seed 1",
"num_predict 1",
"top_k 1",
"top_p 1.0",
"tfs_z 1.0",
"typical_p 1.0",
"repeat_last_n 1",
"temperature 1.0",
"repeat_penalty 1.0",
"presence_penalty 1.0",
"frequency_penalty 1.0",
"mirostat 1",
"mirostat_tau 1.0",
"mirostat_eta 1.0",
"penalize_newline true",
"stop foo",
}
input := ` for _, c := range cases {
t.Run(c, func(t *testing.T) {
var b bytes.Buffer
fmt.Fprintln(&b, "FROM foo")
fmt.Fprintln(&b, "PARAMETER", c)
t.Logf("input: %s", b.String())
_, err := Parse(&b)
assert.Nil(t, err)
})
}
}
func TestParserOnlyFrom(t *testing.T) {
commands, err := Parse(strings.NewReader("FROM foo"))
assert.Nil(t, err)
expected := []Command{{Name: "model", Args: "foo"}}
assert.Equal(t, expected, commands)
}
func TestParserComments(t *testing.T) {
var cases = []struct {
input string
expected []Command
}{
{
`
# comment
FROM foo FROM foo
MESSAGE badguy I'm a bad guy! `,
` []Command{
{Name: "model", Args: "foo"},
},
},
}
reader := strings.NewReader(input) for _, c := range cases {
_, err := Parse(reader) t.Run("", func(t *testing.T) {
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"") commands, err := Parse(strings.NewReader(c.input))
assert.Nil(t, err)
assert.Equal(t, c.expected, commands)
})
}
} }
...@@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) { ...@@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) {
if tc.Expected != nil { if tc.Expected != nil {
tc.Expected(t, resp) tc.Expected(t, resp)
} }
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment