ollama · Commits · Commit 6544e147 (Unverified)

Reapply "add truncate and shift parameters" (#12582)

Authored Oct 11, 2025 by Jeffrey Morgan; committed by GitHub on Oct 11, 2025.
Parent: 5db8a818
Showing 8 changed files with 298 additions and 57 deletions.
api/types.go                    +16   -0
llm/server.go                    +4   -2
runner/llamarunner/runner.go    +22   -0
runner/ollamarunner/runner.go   +25   -0
server/prompt.go                 +2   -2
server/prompt_test.go           +64  -38
server/routes.go                +64  -15
server/routes_generate_test.go +101   -0
api/types.go

@@ -106,6 +106,14 @@ type GenerateRequest struct {
 	// before this option was introduced)
 	Think *ThinkValue `json:"think,omitempty"`
 
+	// Truncate is a boolean that, when set to true, truncates the chat history messages
+	// if the rendered prompt exceeds the context length limit.
+	Truncate *bool `json:"truncate,omitempty"`
+
+	// Shift is a boolean that, when set to true, shifts the chat history
+	// when hitting the context length limit instead of erroring.
+	Shift *bool `json:"shift,omitempty"`
+
 	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
 	// template instead of calling the model.
 	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`

@@ -140,6 +148,14 @@ type ChatRequest struct {
 	// for supported models.
 	Think *ThinkValue `json:"think,omitempty"`
 
+	// Truncate is a boolean that, when set to true, truncates the chat history messages
+	// if the rendered prompt exceeds the context length limit.
+	Truncate *bool `json:"truncate,omitempty"`
+
+	// Shift is a boolean that, when set to true, shifts the chat history
+	// when hitting the context length limit instead of erroring.
+	Shift *bool `json:"shift,omitempty"`
+
 	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
 	// template instead of calling the model.
 	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
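Both request types expose the new options as optional booleans (pointers), so omitting them keeps the previous behavior: truncation and shifting stay enabled by default. As a minimal sketch of opting out of truncation, assuming the ollama Go client (github.com/ollama/ollama/api) and a hypothetical locally pulled model named "llama3.2":

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	truncate := false // reject over-long prompts instead of trimming history
	shift := true     // keep shifting context if generation hits the limit

	req := &api.ChatRequest{
		Model:    "llama3.2", // hypothetical model name for illustration
		Messages: []api.Message{{Role: "user", Content: "Hello!"}},
		Truncate: &truncate,
		Shift:    &shift,
	}

	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		// with Truncate=false, an over-long prompt surfaces here as an error
		log.Fatal(err)
	}
}

With Truncate set to false, an oversized prompt fails fast rather than being silently trimmed, which matters when dropped history would quietly change the model's answer.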
llm/server.go

@@ -1379,7 +1379,9 @@ type CompletionRequest struct {
 	Images  []ImageData
 	Options *api.Options
 
 	Grammar string // set before sending the request to the subprocess
+	Shift    bool
+	Truncate bool
 }
 
 // DoneReason represents the reason why a completion response is done

@@ -1501,7 +1503,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			return fmt.Errorf("failed reading llm error response: %w", err)
 		}
 		log.Printf("llm predict error: %s", bodyBytes)
-		return fmt.Errorf("%s", bodyBytes)
+		return api.StatusError{StatusCode: res.StatusCode, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
 	}
 
 	scanner := bufio.NewScanner(res.Body)
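Returning api.StatusError here means a failed predict request now carries the runner's HTTP status code instead of collapsing into an opaque fmt.Errorf string, so callers can recover the code with errors.As. A minimal sketch of that unwrapping pattern (handleErr is a hypothetical helper; the handlers in server/routes.go below do the equivalent inline):

package main

import (
	"errors"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/api"
)

// handleErr maps a completion error onto an HTTP status and message,
// preserving the runner's status code when one is attached.
func handleErr(err error) (int, string) {
	var serr api.StatusError
	if errors.As(err, &serr) {
		// e.g. 400 from the runner when the input exceeds the context length
		return serr.StatusCode, serr.ErrorMessage
	}
	return http.StatusInternalServerError, err.Error()
}

func main() {
	err := api.StatusError{StatusCode: http.StatusBadRequest, ErrorMessage: "the input length exceeds the context length"}
	status, msg := handleErr(err)
	fmt.Println(status, msg) // 400 the input length exceeds the context length
}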
runner/llamarunner/runner.go

@@ -79,6 +79,9 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
+	// shift if context window is exceeded
+	shift bool
+
 	doneReason llm.DoneReason
 
 	// Metrics

@@ -94,8 +97,12 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	shift          bool
+	truncate       bool
 }
 
+var errorInputTooLong = errors.New("the input length exceeds the context length")
+
 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()

@@ -119,6 +126,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	if len(inputs) > s.cache.numCtx {
 		discard := len(inputs) - s.cache.numCtx
+		if !params.truncate {
+			return nil, errorInputTooLong
+		}
+
 		newInputs := inputs[:params.numKeep]
 		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)

@@ -385,6 +396,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		for i, input := range seq.inputs {
 			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
 				if len(seq.pendingInputs) == 0 {
+					if !seq.shift {
+						s.removeSequence(seqIdx, llm.DoneReasonLength)
+						break
+					}
+
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 					if err != nil {
 						var reprocess *ErrReprocessInputs

@@ -583,8 +599,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:        req.Options.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		shift:          req.Shift,
+		truncate:       req.Truncate,
 	})
 	if err != nil {
+		if errors.Is(err, errorInputTooLong) {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}
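The truncation path in NewSequence keeps the first params.numKeep inputs, drops the `discard` inputs immediately after them, and keeps the remainder, so the head and tail of the prompt both survive. A standalone sketch of the same slice arithmetic, with the runner's input type reduced to ints (truncateInputs is an illustrative helper, not part of the diff):

package main

import "fmt"

// truncateInputs mirrors the runner's middle-truncation: keep the first
// numKeep inputs, drop the overflow immediately after them, keep the rest.
func truncateInputs(inputs []int, numCtx, numKeep int) []int {
	if len(inputs) <= numCtx {
		return inputs
	}
	discard := len(inputs) - numCtx
	out := append([]int{}, inputs[:numKeep]...)
	return append(out, inputs[numKeep+discard:]...)
}

func main() {
	inputs := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	// numCtx=8, numKeep=2: discard=2, so inputs 2 and 3 are dropped
	fmt.Println(truncateInputs(inputs, 8, 2)) // [0 1 4 5 6 7 8 9]
}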
runner/ollamarunner/runner.go

@@ -88,6 +88,9 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
+	// shift if context window is exceeded
+	shift bool
+
 	doneReason llm.DoneReason
 
 	// Metrics

@@ -104,8 +107,12 @@ type NewSequenceParams struct {
 	numKeep   int32
 	sampler   sample.Sampler
 	embedding bool
+	shift     bool
+	truncate  bool
 }
 
+var errorInputTooLong = errors.New("the input length exceeds the context length")
+
 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()

@@ -125,6 +132,11 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	if int32(len(inputs)) > s.cache.numCtx {
 		discard := int32(len(inputs)) - s.cache.numCtx
+
+		if !params.truncate {
+			return nil, errorInputTooLong
+		}
+
 		promptStart := params.numKeep + discard
 
 		// If we need to truncate in the middle of a unbreakable batch, remove the entire batch

@@ -176,6 +188,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		embeddingOnly: params.embedding,
 		stop:          params.stop,
 		numKeep:       params.numKeep,
+		shift:         params.shift,
 	}, nil
 }

@@ -517,6 +530,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 				break
 			}
 
+			if !seq.shift {
+				s.removeSequence(seqIdx, llm.DoneReasonLength)
+				nextBatch.seqs[seqIdx] = nil
+				break
+			}
+
 			err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 			if err != nil {
 				var reprocess *ErrReprocessInputs

@@ -832,8 +851,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:   int32(req.Options.NumKeep),
 		sampler:   sampler,
 		embedding: false,
+		shift:     req.Shift,
+		truncate:  req.Truncate,
 	})
 	if err != nil {
+		if errors.Is(err, errorInputTooLong) {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}
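Both runners now take the same decision when a running sequence fills its cache slot: with shift enabled, ShiftCacheSlot frees space (keeping the first numKeep tokens) and generation continues; with shift disabled, the sequence ends with llm.DoneReasonLength, which clients observe as a normal stop rather than an error. A compact sketch of that branch, with the runner types reduced to hypothetical stand-ins (continueOrStop and shiftSlot are illustrative names only):

package main

import "fmt"

// continueOrStop reduces the runners' context-full branch to its essentials.
// shiftSlot stands in for InputCache.ShiftCacheSlot, and the bool result says
// whether generation should continue.
func continueOrStop(shift bool, numKeep int, shiftSlot func(int) error) (bool, error) {
	if !shift {
		// without shift the sequence is removed with llm.DoneReasonLength,
		// which clients see as done_reason "length" rather than an error
		return false, nil
	}
	if err := shiftSlot(numKeep); err != nil {
		// e.g. ErrReprocessInputs, handled by the callers in the diff above
		return false, err
	}
	return true, nil
}

func main() {
	shiftSlot := func(numKeep int) error {
		fmt.Printf("shifting cache slot, keeping the first %d tokens\n", numKeep)
		return nil
	}

	fmt.Println(continueOrStop(true, 4, shiftSlot))  // shifts, then: true <nil>
	fmt.Println(continueOrStop(false, 4, shiftSlot)) // false <nil>
}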
server/prompt.go

@@ -20,7 +20,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent

@@ -59,7 +59,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			}
 		}
 
-		if ctxLen > opts.NumCtx {
+		if truncate && ctxLen > opts.NumCtx {
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
 		} else {
server/prompt_test.go

@@ -27,16 +27,18 @@ func TestChatPrompt(t *testing.T) {
 	visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 
 	cases := []struct {
-		name  string
-		model Model
-		limit int
-		msgs  []api.Message
+		name     string
+		model    Model
+		limit    int
+		truncate bool
+		msgs     []api.Message
 		expect
 	}{
 		{
-			name:  "messages",
-			model: visionModel,
-			limit: 64,
+			name:     "messages",
+			model:    visionModel,
+			limit:    64,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -47,9 +49,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate messages",
-			model: visionModel,
-			limit: 1,
+			name:     "truncate messages",
+			model:    visionModel,
+			limit:    1,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -60,9 +63,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate messages with image",
-			model: visionModel,
-			limit: 64,
+			name:     "truncate messages with image",
+			model:    visionModel,
+			limit:    64,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -76,9 +80,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate messages with images",
-			model: visionModel,
-			limit: 64,
+			name:     "truncate messages with images",
+			model:    visionModel,
+			limit:    64,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -92,9 +97,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "messages with images",
-			model: visionModel,
-			limit: 2048,
+			name:     "messages with images",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -109,9 +115,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "message with image tag",
-			model: visionModel,
-			limit: 2048,
+			name:     "message with image tag",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -126,9 +133,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "messages with interleaved images",
-			model: visionModel,
-			limit: 2048,
+			name:     "messages with interleaved images",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "user", Images: []api.ImageData{[]byte("something")}},

@@ -145,9 +153,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate message with interleaved images",
-			model: visionModel,
-			limit: 1024,
+			name:     "truncate message with interleaved images",
+			model:    visionModel,
+			limit:    1024,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "user", Images: []api.ImageData{[]byte("something")}},

@@ -163,9 +172,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "message with system prompt",
-			model: visionModel,
-			limit: 2048,
+			name:     "message with system prompt",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "system", Content: "You are the Test Who Lived."},
 				{Role: "user", Content: "You're a test, Harry!"},

@@ -177,9 +187,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "out of order system",
-			model: visionModel,
-			limit: 2048,
+			name:     "out of order system",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -191,9 +202,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "multiple images same prompt",
-			model: visionModel,
-			limit: 2048,
+			name:     "multiple images same prompt",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
 			},

@@ -202,6 +214,20 @@ func TestChatPrompt(t *testing.T) {
 				images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
 			},
 		},
+		{
+			name:     "no truncate with limit exceeded",
+			model:    visionModel,
+			limit:    10,
+			truncate: false,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+			},
+		},
 	}
 
 	for _, tt := range cases {

@@ -209,7 +235,7 @@ func TestChatPrompt(t *testing.T) {
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			think := false
-			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think})
+			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
 			if tt.error == nil && err != nil {
 				t.Fatal(err)
 			} else if tt.error != nil && err != tt.error {
server/routes.go

@@ -434,7 +434,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	// the real chat handler, but doing this as a stopgap to get renderer
 	// support for generate
 	if values.Messages != nil && values.Suffix == "" && req.Template == "" {
-		prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think)
+		prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return

@@ -488,10 +488,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:   prompt,
 			Images:   images,
 			Format:   req.Format,
 			Options:  opts,
+			Shift:    req.Shift == nil || *req.Shift,
+			Truncate: req.Truncate == nil || *req.Truncate,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model: req.Model,

@@ -553,7 +555,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			ch <- res
 		}); err != nil {
-			ch <- gin.H{"error": err.Error()}
+			var serr api.StatusError
+			if errors.As(err, &serr) {
+				ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
+			} else {
+				ch <- gin.H{"error": err.Error()}
+			}
 		}
 	}()

@@ -573,7 +580,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				msg = "unexpected error format in response"
 			}
-			c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+			status, ok := t["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
+
+			c.JSON(status, gin.H{"error": msg})
 			return
 		default:
 			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})

@@ -1638,6 +1650,30 @@ func streamResponse(c *gin.Context, ch chan any) {
 			return false
 		}
 
+		// errors are provided as a gin.H with an "error" field and
+		// an optional "status" field. For errors that are streamed
+		// before any content, we need to set the status code and
+		// content type for the error.
+		if h, ok := val.(gin.H); ok {
+			if e, ok := h["error"].(string); ok {
+				status, ok := h["status"].(int)
+				if !ok {
+					status = http.StatusInternalServerError
+				}
+
+				if !c.Writer.Written() {
+					c.Header("Content-Type", "application/json")
+					c.JSON(status, gin.H{"error": e})
+				} else {
+					if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil {
+						slog.Error("streamResponse failed to encode json error", "error", err)
+					}
+				}
+
+				return false
+			}
+		}
+
 		bts, err := json.Marshal(val)
 		if err != nil {
 			slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))

@@ -1957,7 +1993,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
+	truncate := req.Truncate == nil || *req.Truncate
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

@@ -2034,10 +2071,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		// sets up new context given parent context per request
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		err := r.Completion(ctx, llm.CompletionRequest{
 			Prompt:   prompt,
 			Images:   images,
 			Format:   currentFormat,
 			Options:  opts,
+			Shift:    req.Shift == nil || *req.Shift,
+			Truncate: truncate,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model: req.Model,

@@ -2131,7 +2170,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
 				// only ignores error if it's a context cancellation due to setting structured outputs
 			} else {
-				ch <- gin.H{"error": err.Error()}
+				var serr api.StatusError
+				if errors.As(err, &serr) {
+					ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
+				} else {
+					ch <- gin.H{"error": err.Error()}
+				}
 				return
 			}
 		}

@@ -2145,7 +2189,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 		msgs = append(msgs, msg)
 
-		prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
+		prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 		if err != nil {
 			slog.Error("chat prompt error applying structured outputs", "error", err)
 			ch <- gin.H{"error": err.Error()}

@@ -2185,7 +2229,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				msg = "unexpected error format in response"
 			}
-			c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+			status, ok := t["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
+
+			c.JSON(status, gin.H{"error": msg})
 			return
 		default:
 			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
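Taken together, the routes changes mean a streamed request that fails before any content now receives a real HTTP status code plus a one-line JSON error body, rather than a blanket 500. A minimal client-side sketch, assuming a local ollama server on its default port and a model named "test" as in the tests below:

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	// streaming generate request; truncate/shift omitted, so defaults apply
	body := []byte(`{"model":"test","prompt":"Hello!"}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// each line of the stream is one JSON object; an "error" field means the
	// request failed, and resp.StatusCode carries the mapped code (e.g. 429
	// for "rate limit exceeded") when the error arrived before any content
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var chunk map[string]any
		if err := json.Unmarshal(scanner.Bytes(), &chunk); err != nil {
			log.Fatal(err)
		}
		if msg, ok := chunk["error"].(string); ok {
			log.Fatalf("server error (HTTP %d): %s", resp.StatusCode, msg)
		}
		if s, ok := chunk["response"].(string); ok {
			fmt.Print(s)
		}
	}
}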
server/routes_generate_test.go

@@ -609,6 +609,58 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
 		}
 	})
 
+	t.Run("status error non-streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusServiceUnavailable,
+				Status:       "Service Unavailable",
+				ErrorMessage: "model is overloaded",
+			}
+		}
+
+		stream := false
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusServiceUnavailable {
+			t.Errorf("expected status 503, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("status error streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusTooManyRequests,
+				Status:       "Too Many Requests",
+				ErrorMessage: "rate limit exceeded",
+			}
+		}
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+		})
+
+		if w.Code != http.StatusTooManyRequests {
+			t.Errorf("expected status 429, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
 }
 
 func TestGenerate(t *testing.T) {

@@ -983,6 +1035,55 @@ func TestGenerate(t *testing.T) {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
 
+	t.Run("status error non-streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusServiceUnavailable,
+				Status:       "Service Unavailable",
+				ErrorMessage: "model is overloaded",
+			}
+		}
+
+		streamRequest := false
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Hello!",
+			Stream: &streamRequest,
+		})
+
+		if w.Code != http.StatusServiceUnavailable {
+			t.Errorf("expected status 503, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("status error streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusTooManyRequests,
+				Status:       "Too Many Requests",
+				ErrorMessage: "rate limit exceeded",
+			}
+		}
+
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Hello!",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusTooManyRequests {
+			t.Errorf("expected status 429, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
 }
 
 func TestChatWithPromptEndingInThinkTag(t *testing.T) {