Commit 03e40efa authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

runner.go: Merge partial unicode characters before sending

We check for partial unicode characters and accumulate them before
sending. However, when we did send, we still sent each individual piece
separately, leading to broken output. This combines everything into
a single group, which is also more efficient.

This also switches to the built-in check for valid unicode characters,
which is stricter. After this, we should never send back an invalid
sequence.

Fixes #7290
parent 23f74650
...@@ -30,6 +30,22 @@ func TestOrcaMiniBlueSky(t *testing.T) { ...@@ -30,6 +30,22 @@ func TestOrcaMiniBlueSky(t *testing.T) {
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
} }
// TestUnicodeOutput issues a generate request against gemma2:2b with a
// 2-minute deadline, temperature 0 and a fixed seed, and passes it to
// GenerateTestHelper along with a list of emoji strings.
// NOTE(review): presumably GenerateTestHelper checks the response for these
// strings — confirm against its definition.
func TestUnicodeOutput(t *testing.T) {
	// Bound the whole exchange so the test cannot wait past two minutes.
	deadlineCtx, stop := context.WithTimeout(context.Background(), 2*time.Minute)
	defer stop()

	// Sampling options attached to the request: zero temperature, fixed seed.
	sampling := map[string]interface{}{
		"temperature": 0,
		"seed":        123,
	}

	request := api.GenerateRequest{
		Model:   "gemma2:2b",
		Prompt:  "Output some smily face emoji",
		Stream:  &stream,
		Options: sampling,
	}

	expected := []string{"😀", "😊", "😁", "😂", "😄", "😃"}
	GenerateTestHelper(deadlineCtx, t, request, expected)
}
func TestUnicodeModelDir(t *testing.T) { func TestUnicodeModelDir(t *testing.T) {
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms // This is only useful for Windows with utf-16 characters, so skip this test for other platforms
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"unicode/utf8"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
...@@ -293,17 +294,29 @@ func (s *Server) shiftContext(seq *Sequence) { ...@@ -293,17 +294,29 @@ func (s *Server) shiftContext(seq *Sequence) {
} }
// flushPending, as shown in the right-hand (post-commit) column of this
// side-by-side diff, concatenates seq.pendingResponses into one string,
// resets the pending list, and then drops trailing bytes one at a time until
// utf8.ValidString accepts what remains. It returns true when nothing is left
// to send or when the string is delivered on seq.responses, and false when
// seq.quit is received instead.
//
// Because trimming only happens from the end, an invalid byte in the middle
// of the joined string causes everything after it to be discarded as well.
//
// NOTE(review): the removed (left) and added (right) columns are interleaved
// line by line below, so this text is not compilable as written. The left
// column is the old implementation, which sent each pending piece separately.
func flushPending(seq *Sequence) bool { func flushPending(seq *Sequence) bool {
for _, p := range seq.pendingResponses { joined := strings.Join(seq.pendingResponses, "")
select {
case seq.responses <- p:
case <-seq.quit:
seq.pendingResponses = []string{} seq.pendingResponses = []string{}
return false
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
} }
if len(joined) == 0 {
return true
} }
seq.pendingResponses = []string{} select {
case seq.responses <- joined:
return true return true
case <-seq.quit:
return false
}
} }
func (s *Server) removeSequence(seqIndex int, reason string) { func (s *Server) removeSequence(seqIndex int, reason string) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment