OpenDAS / ollama · commit 6544e147 (unverified)

Reapply "add truncate and shift parameters" (#12582)

Authored by Jeffrey Morgan, committed via GitHub on Oct 11, 2025. Parent: 5db8a818

Showing 8 changed files with 298 additions and 57 deletions:

api/types.go                    +16   -0
llm/server.go                    +4   -2
runner/llamarunner/runner.go    +22   -0
runner/ollamarunner/runner.go   +25   -0
server/prompt.go                 +2   -2
server/prompt_test.go           +64  -38
server/routes.go                +64  -15
server/routes_generate_test.go +101   -0
api/types.go

@@ -106,6 +106,14 @@ type GenerateRequest struct {
 	// before this option was introduced)
 	Think *ThinkValue `json:"think,omitempty"`
 
+	// Truncate is a boolean that, when set to true, truncates the chat history messages
+	// if the rendered prompt exceeds the context length limit.
+	Truncate *bool `json:"truncate,omitempty"`
+
+	// Shift is a boolean that, when set to true, shifts the chat history
+	// when hitting the context length limit instead of erroring.
+	Shift *bool `json:"shift,omitempty"`
+
 	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
 	// template instead of calling the model.
 	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`

@@ -140,6 +148,14 @@ type ChatRequest struct {
 	// for supported models.
 	Think *ThinkValue `json:"think,omitempty"`
 
+	// Truncate is a boolean that, when set to true, truncates the chat history messages
+	// if the rendered prompt exceeds the context length limit.
+	Truncate *bool `json:"truncate,omitempty"`
+
+	// Shift is a boolean that, when set to true, shifts the chat history
+	// when hitting the context length limit instead of erroring.
+	Shift *bool `json:"shift,omitempty"`
+
 	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
 	// template instead of calling the model.
 	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
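Both new fields are optional *bool values, so existing clients are unaffected: leaving a field nil keeps the previous behavior (truncation and shifting enabled), while an explicit false opts out and surfaces an error instead. A minimal client-side sketch of setting them; the model name is a placeholder and only the Truncate and Shift fields come from this commit:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// Truncate and Shift are *bool: nil means "use the default" (enabled),
	// an explicit false opts out of the behavior.
	truncate := false // error instead of dropping old messages when the prompt is too long
	shift := false    // error instead of shifting the context window once it fills up

	req := api.ChatRequest{
		Model: "test-model", // placeholder name, not part of this commit
		Messages: []api.Message{
			{Role: "user", Content: "Hello!"},
		},
		Truncate: &truncate,
		Shift:    &shift,
	}

	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // the new fields serialize as "truncate" and "shift"
}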
llm/server.go

@@ -1379,7 +1379,9 @@ type CompletionRequest struct {
 	Images  []ImageData
 	Options *api.Options
 
 	Grammar string // set before sending the request to the subprocess
+	Shift    bool
+	Truncate bool
 }
 
 // DoneReason represents the reason why a completion response is done

@@ -1501,7 +1503,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			return fmt.Errorf("failed reading llm error response: %w", err)
 		}
 		log.Printf("llm predict error: %s", bodyBytes)
-		return fmt.Errorf("%s", bodyBytes)
+		return api.StatusError{StatusCode: res.StatusCode, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
 	}
 
 	scanner := bufio.NewScanner(res.Body)
runner/llamarunner/runner.go

@@ -79,6 +79,9 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
+	// shift if context window is exceeded
+	shift bool
+
 	doneReason llm.DoneReason
 
 	// Metrics

@@ -94,8 +97,12 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	shift          bool
+	truncate       bool
 }
 
+var errorInputTooLong = errors.New("the input length exceeds the context length")
+
 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()

@@ -119,6 +126,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	if len(inputs) > s.cache.numCtx {
 		discard := len(inputs) - s.cache.numCtx
+		if !params.truncate {
+			return nil, errorInputTooLong
+		}
+
 		newInputs := inputs[:params.numKeep]
 		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)

@@ -385,6 +396,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		for i, input := range seq.inputs {
 			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
 				if len(seq.pendingInputs) == 0 {
+					if !seq.shift {
+						s.removeSequence(seqIdx, llm.DoneReasonLength)
+						break
+					}
+
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 					if err != nil {
 						var reprocess *ErrReprocessInputs

@@ -583,8 +599,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:        req.Options.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		shift:          req.Shift,
+		truncate:       req.Truncate,
 	})
 	if err != nil {
+		if errors.Is(err, errorInputTooLong) {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}
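The truncation path above keeps the first numKeep inputs, drops the next discard inputs, and keeps the remainder; with truncate disabled it now fails with errorInputTooLong (surfaced as HTTP 400) instead of silently dropping tokens. A standalone sketch of that slice arithmetic over plain ints, not a function from the runner:

package main

import (
	"errors"
	"fmt"
)

// truncateInputs mirrors the logic in NewSequence: if the prompt fits, keep it;
// if truncation is disallowed, fail; otherwise keep the first numKeep entries,
// discard the overflow right after them, and keep the rest.
func truncateInputs(inputs []int, numKeep, numCtx int, truncate bool) ([]int, error) {
	if len(inputs) <= numCtx {
		return inputs, nil
	}
	if !truncate {
		return nil, errors.New("the input length exceeds the context length")
	}
	discard := len(inputs) - numCtx
	out := append([]int{}, inputs[:numKeep]...)
	return append(out, inputs[numKeep+discard:]...), nil
}

func main() {
	inputs := make([]int, 12)
	for i := range inputs {
		inputs[i] = i
	}

	kept, _ := truncateInputs(inputs, 2, 8, true)
	fmt.Println(kept) // [0 1 6 7 8 9 10 11]: four tokens after numKeep were discarded

	_, err := truncateInputs(inputs, 2, 8, false)
	fmt.Println(err) // the input length exceeds the context length
}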
runner/ollamarunner/runner.go

@@ -88,6 +88,9 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
+	// shift if context window is exceeded
+	shift bool
+
 	doneReason llm.DoneReason
 
 	// Metrics

@@ -104,8 +107,12 @@ type NewSequenceParams struct {
 	numKeep   int32
 	sampler   sample.Sampler
 	embedding bool
+	shift     bool
+	truncate  bool
 }
 
+var errorInputTooLong = errors.New("the input length exceeds the context length")
+
 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()

@@ -125,6 +132,11 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	if int32(len(inputs)) > s.cache.numCtx {
 		discard := int32(len(inputs)) - s.cache.numCtx
+
+		if !params.truncate {
+			return nil, errorInputTooLong
+		}
+
 		promptStart := params.numKeep + discard
 
 		// If we need to truncate in the middle of a unbreakable batch, remove the entire batch

@@ -176,6 +188,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		embeddingOnly: params.embedding,
 		stop:          params.stop,
 		numKeep:       params.numKeep,
+		shift:         params.shift,
 	}, nil
 }

@@ -517,6 +530,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 				break
 			}
 
+			if !seq.shift {
+				s.removeSequence(seqIdx, llm.DoneReasonLength)
+				nextBatch.seqs[seqIdx] = nil
+				break
+			}
+
 			err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 			if err != nil {
 				var reprocess *ErrReprocessInputs

@@ -832,8 +851,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:   int32(req.Options.NumKeep),
 		sampler:   sampler,
 		embedding: false,
+		shift:     req.Shift,
+		truncate:  req.Truncate,
 	})
 	if err != nil {
+		if errors.Is(err, errorInputTooLong) {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}
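In both runners the batch loop now gives a full cache slot two outcomes: a sequence that is not allowed to shift finishes with DoneReasonLength, while one that is allowed has its slot compacted (keeping the first numKeep tokens) and continues decoding. A simplified sketch of that decision with stand-in types; the real ShiftCacheSlot lives in the runner's cache, and the "drop half of the tokens after numKeep" policy below is an assumption for illustration only:

package main

import "fmt"

// Simplified stand-ins for the runner's cache slot and sequence state.
type slot struct{ tokens []int }

type sequence struct {
	shift   bool
	numKeep int
	cache   *slot
}

// shiftCacheSlot is a stand-in for the cache's ShiftCacheSlot; here it is
// assumed to drop half of the tokens after numKeep to make room.
func shiftCacheSlot(s *slot, numKeep int) {
	discard := (len(s.tokens) - numKeep) / 2
	s.tokens = append(s.tokens[:numKeep], s.tokens[numKeep+discard:]...)
}

// onContextFull mirrors the branch added to the batch loop: without shift the
// sequence is done with reason "length"; with shift the slot is compacted and
// generation continues.
func onContextFull(seq *sequence) (done bool, reason string) {
	if !seq.shift {
		return true, "length"
	}
	shiftCacheSlot(seq.cache, seq.numKeep)
	return false, ""
}

func main() {
	noShift := &sequence{shift: false, numKeep: 4, cache: &slot{tokens: make([]int, 16)}}
	withShift := &sequence{shift: true, numKeep: 4, cache: &slot{tokens: make([]int, 16)}}

	done, reason := onContextFull(noShift)
	fmt.Println(done, reason) // true length

	done, _ = onContextFull(withShift)
	fmt.Println(done, len(withShift.cache.tokens)) // false 10
}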
server/prompt.go

@@ -20,7 +20,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent

@@ -59,7 +59,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}
 	}
 
-	if ctxLen > opts.NumCtx {
+	if truncate && ctxLen > opts.NumCtx {
 		slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 		break
 	} else {
server/prompt_test.go

@@ -27,16 +27,18 @@ func TestChatPrompt(t *testing.T) {
 	visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 
 	cases := []struct {
 		name     string
 		model    Model
 		limit    int
+		truncate bool
 		msgs     []api.Message
 		expect
 	}{
 		{
 			name:     "messages",
 			model:    visionModel,
 			limit:    64,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -47,9 +49,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "truncate messages",
 			model:    visionModel,
 			limit:    1,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -60,9 +63,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "truncate messages with image",
 			model:    visionModel,
 			limit:    64,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -76,9 +80,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "truncate messages with images",
 			model:    visionModel,
 			limit:    64,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -92,9 +97,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "messages with images",
 			model:    visionModel,
 			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -109,9 +115,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "message with image tag",
 			model:    visionModel,
 			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -126,9 +133,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "messages with interleaved images",
 			model:    visionModel,
 			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "user", Images: []api.ImageData{[]byte("something")}},

@@ -145,9 +153,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "truncate message with interleaved images",
 			model:    visionModel,
 			limit:    1024,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "user", Images: []api.ImageData{[]byte("something")}},

@@ -163,9 +172,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "message with system prompt",
 			model:    visionModel,
 			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "system", Content: "You are the Test Who Lived."},
 				{Role: "user", Content: "You're a test, Harry!"},

@@ -177,9 +187,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "out of order system",
 			model:    visionModel,
 			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},

@@ -191,9 +202,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
 			name:     "multiple images same prompt",
 			model:    visionModel,
 			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
 			},

@@ -202,6 +214,20 @@ func TestChatPrompt(t *testing.T) {
 				images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
 			},
 		},
+		{
+			name:     "no truncate with limit exceeded",
+			model:    visionModel,
+			limit:    10,
+			truncate: false,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+			},
+		},
 	}
 
 	for _, tt := range cases {

@@ -209,7 +235,7 @@ func TestChatPrompt(t *testing.T) {
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			think := false
-			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think})
+			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
 			if tt.error == nil && err != nil {
 				t.Fatal(err)
 			} else if tt.error != nil && err != tt.error {
server/routes.go

@@ -434,7 +434,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	// the real chat handler, but doing this as a stopgap to get renderer
 	// support for generate
 	if values.Messages != nil && values.Suffix == "" && req.Template == "" {
-		prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think)
+		prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return

@@ -488,10 +488,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
 			Options: opts,
+			Shift:    req.Shift == nil || *req.Shift,
+			Truncate: req.Truncate == nil || *req.Truncate,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model: req.Model,

@@ -553,7 +555,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			ch <- res
 		}); err != nil {
-			ch <- gin.H{"error": err.Error()}
+			var serr api.StatusError
+			if errors.As(err, &serr) {
+				ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
+			} else {
+				ch <- gin.H{"error": err.Error()}
+			}
 		}
 	}()

@@ -573,7 +580,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				msg = "unexpected error format in response"
 			}
-			c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+			status, ok := t["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
+
+			c.JSON(status, gin.H{"error": msg})
 			return
 		default:
 			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})

@@ -1638,6 +1650,30 @@ func streamResponse(c *gin.Context, ch chan any) {
 			return false
 		}
 
+		// errors are provided as a gin.H with an "error" field and
+		// an optional "status" field. For errors that are streamed
+		// before any content, we need to set the status code and
+		// content type for the error.
+		if h, ok := val.(gin.H); ok {
+			if e, ok := h["error"].(string); ok {
+				status, ok := h["status"].(int)
+				if !ok {
+					status = http.StatusInternalServerError
+				}
+
+				if !c.Writer.Written() {
+					c.Header("Content-Type", "application/json")
+					c.JSON(status, gin.H{"error": e})
+				} else {
+					if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil {
+						slog.Error("streamResponse failed to encode json error", "error", err)
+					}
+				}
+
+				return false
+			}
+		}
+
 		bts, err := json.Marshal(val)
 		if err != nil {
 			slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))

@@ -1957,7 +1993,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
+	truncate := req.Truncate == nil || *req.Truncate
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

@@ -2034,10 +2071,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		// sets up new context given parent context per request
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		err := r.Completion(ctx, llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  currentFormat,
 			Options: opts,
+			Shift:    req.Shift == nil || *req.Shift,
+			Truncate: truncate,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model: req.Model,

@@ -2131,7 +2170,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
 				// only ignores error if it's a context cancellation due to setting structured outputs
 			} else {
-				ch <- gin.H{"error": err.Error()}
+				var serr api.StatusError
+				if errors.As(err, &serr) {
+					ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
+				} else {
+					ch <- gin.H{"error": err.Error()}
+				}
 				return
 			}
 		}

@@ -2145,7 +2189,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 
 			msgs = append(msgs, msg)
-			prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
+			prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 			if err != nil {
 				slog.Error("chat prompt error applying structured outputs", "error", err)
 				ch <- gin.H{"error": err.Error()}

@@ -2185,7 +2229,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				msg = "unexpected error format in response"
 			}
-			c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+			status, ok := t["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
+
+			c.JSON(status, gin.H{"error": msg})
 			return
 		default:
 			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
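With these route changes, a completion error that carries an api.StatusError is no longer flattened to a 500: the handler unwraps it with errors.As, streams the message together with an optional "status" field, and streamResponse (or the non-streaming path) writes that status to the client. The new tests below exercise exactly this path. A minimal sketch of the unwrapping step, using only the types and error values that appear in this diff:

package main

import (
	"errors"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/api"
)

// complete stands in for a Completion call that failed; the error value
// mirrors the one constructed in the new tests.
func complete() error {
	return api.StatusError{
		StatusCode:   http.StatusServiceUnavailable,
		Status:       "Service Unavailable",
		ErrorMessage: "model is overloaded",
	}
}

func main() {
	err := complete()

	// The same unwrapping the handlers now perform before sending on the channel.
	var serr api.StatusError
	if errors.As(err, &serr) {
		fmt.Println(serr.StatusCode, serr.ErrorMessage) // 503 model is overloaded
	} else {
		fmt.Println(http.StatusInternalServerError, err.Error())
	}
}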
server/routes_generate_test.go

@@ -609,6 +609,58 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
 		}
 	})
+
+	t.Run("status error non-streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusServiceUnavailable,
+				Status:       "Service Unavailable",
+				ErrorMessage: "model is overloaded",
+			}
+		}
+
+		stream := false
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusServiceUnavailable {
+			t.Errorf("expected status 503, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("status error streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusTooManyRequests,
+				Status:       "Too Many Requests",
+				ErrorMessage: "rate limit exceeded",
+			}
+		}
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+		})
+
+		if w.Code != http.StatusTooManyRequests {
+			t.Errorf("expected status 429, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
 }
 
 func TestGenerate(t *testing.T) {

@@ -983,6 +1035,55 @@ func TestGenerate(t *testing.T) {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
+
+	t.Run("status error non-streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusServiceUnavailable,
+				Status:       "Service Unavailable",
+				ErrorMessage: "model is overloaded",
+			}
+		}
+
+		streamRequest := false
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Hello!",
+			Stream: &streamRequest,
+		})
+
+		if w.Code != http.StatusServiceUnavailable {
+			t.Errorf("expected status 503, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("status error streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusTooManyRequests,
+				Status:       "Too Many Requests",
+				ErrorMessage: "rate limit exceeded",
+			}
+		}
+
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Hello!",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusTooManyRequests {
+			t.Errorf("expected status 429, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
 }
 
 func TestChatWithPromptEndingInThinkTag(t *testing.T) {