Commit 184ad8f0 authored by Bruce MacDonald's avatar Bruce MacDonald
Browse files

allow specifying stop conditions in modelfile

parent 822a0e36
...@@ -178,7 +178,7 @@ type Options struct { ...@@ -178,7 +178,7 @@ type Options struct {
MirostatTau float32 `json:"mirostat_tau,omitempty"` MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"` MirostatEta float32 `json:"mirostat_eta,omitempty"`
PenalizeNewline bool `json:"penalize_newline,omitempty"` PenalizeNewline bool `json:"penalize_newline,omitempty"`
StopConditions []string `json:"stop_conditions,omitempty"` Stop []string `json:"stop,omitempty"`
NumThread int `json:"num_thread,omitempty"` NumThread int `json:"num_thread,omitempty"`
} }
......
...@@ -246,7 +246,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) ...@@ -246,7 +246,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
} }
func (llm *LLM) checkStopConditions(b bytes.Buffer) error { func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
for _, stopCondition := range llm.StopConditions { for _, stopCondition := range llm.Stop {
if stopCondition == b.String() { if stopCondition == b.String() {
return io.EOF return io.EOF
} else if strings.HasPrefix(stopCondition, b.String()) { } else if strings.HasPrefix(stopCondition, b.String()) {
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings" "strings"
"text/template" "text/template"
...@@ -472,6 +473,14 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) { ...@@ -472,6 +473,14 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
field.SetBool(boolVal) field.SetBool(boolVal)
case reflect.String: case reflect.String:
field.SetString(val) field.SetString(val)
case reflect.Slice:
re := regexp.MustCompile(`"(.*?)"`) // matches everything enclosed in quotes
vals := re.FindAllStringSubmatch(val, -1)
var sliceVal []string
for _, v := range vals {
sliceVal = append(sliceVal, v[1]) // v[1] is the captured group, v[0] is the entire match
}
field.Set(reflect.ValueOf(sliceVal))
default: default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment