Unverified Commit 8fa477fa authored by Michael Yang, committed by GitHub

Merge pull request #225 from jmorganca/stop-conditions

add stop conditions
parents 01d155c9 fadf75f9
@@ -178,6 +178,7 @@ type Options struct {
 	MirostatTau     float32 `json:"mirostat_tau,omitempty"`
 	MirostatEta     float32 `json:"mirostat_eta,omitempty"`
 	PenalizeNewline bool    `json:"penalize_newline,omitempty"`
+	StopConditions  []string `json:"stop_conditions,omitempty"`
 	NumThread       int     `json:"num_thread,omitempty"`
 }
...
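
For reference, the new StopConditions field serializes under the stop_conditions key. A minimal sketch of the wire format (this snippet redeclares only the new field for illustration; the real Options struct carries many more generation parameters):

package main

import (
	"encoding/json"
	"fmt"
)

// Options redeclares just the field added in this commit.
type Options struct {
	StopConditions []string `json:"stop_conditions,omitempty"`
}

func main() {
	b, err := json.Marshal(Options{StopConditions: []string{"###", "User:"}}) // illustrative stop strings
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b)) // {"stop_conditions":["###","User:"]}
}
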
@@ -172,6 +172,8 @@ func (llm *LLM) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
+var errNeedMoreData = errors.New("need more data")
+
 func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
 	C.llama_reset_timings(llm.ctx)
@@ -200,6 +202,17 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		}
 
 		b.WriteString(llm.detokenize(token))
+
+		if err := llm.checkStopConditions(b); err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			} else if errors.Is(err, errNeedMoreData) {
+				continue
+			}
+
+			return err
+		}
+
 		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
 			fn(api.GenerateResponse{Response: b.String()})
 			b.Reset()
@@ -228,6 +241,18 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 	return nil
 }
 
+func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
+	for _, stopCondition := range llm.StopConditions {
+		if stopCondition == b.String() {
+			return io.EOF
+		} else if strings.HasPrefix(stopCondition, b.String()) {
+			return errNeedMoreData
+		}
+	}
+
+	return nil
+}
+
 func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
 	tokens := append(ctx, llm.tokenize(prompt)...)
 	if llm.NumKeep < 0 {
...
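
To make the control flow above easier to follow: checkStopConditions compares the buffered, detokenized output against each configured stop condition. An exact match ends generation (io.EOF), a prefix match holds the buffer and keeps generating (errNeedMoreData), and no match lets the buffer flush as usual. A standalone sketch of that logic, with illustrative stop strings that are not part of this diff:

package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"strings"
)

// Sentinel mirroring the one added in this commit: the buffered output is a
// prefix of some stop condition, so keep generating before emitting anything.
var errNeedMoreData = errors.New("need more data")

// checkStopConditions reproduces the logic from the diff as a free function
// (the commit defines it as a method on *LLM and passes the buffer by value).
func checkStopConditions(stopConditions []string, b *bytes.Buffer) error {
	for _, stopCondition := range stopConditions {
		if stopCondition == b.String() {
			// exact match: signal the caller to stop generating
			return io.EOF
		} else if strings.HasPrefix(stopCondition, b.String()) {
			// partial match: hold the buffer and wait for more tokens
			return errNeedMoreData
		}
	}

	return nil
}

func main() {
	stops := []string{"###"} // hypothetical stop condition for illustration
	var b bytes.Buffer

	b.WriteString("#")
	fmt.Println(checkStopConditions(stops, &b)) // need more data: "#" is a prefix of "###"

	b.WriteString("##")
	fmt.Println(checkStopConditions(stops, &b)) // EOF: buffer now equals "###", stop generating

	b.Reset()
	b.WriteString("hello")
	fmt.Println(checkStopConditions(stops, &b)) // <nil>: no match, the buffer flushes normally
}
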