Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
0b3118e0
Unverified
Commit
0b3118e0
authored
Jan 03, 2024
by
Bruce MacDonald
Committed by
GitHub
Jan 03, 2024
Browse files
fix: relay request opts to loaded llm prediction (#1761)
parent
05face44
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
103 additions
and
68 deletions
+103
-68
llm/ext_server_common.go
llm/ext_server_common.go
+18
-18
llm/ext_server_default.go
llm/ext_server_default.go
+1
-1
llm/llama.go
llm/llama.go
+4
-3
llm/shim_ext_server.go
llm/shim_ext_server.go
+1
-1
server/routes.go
server/routes.go
+79
-45
No files found.
llm/ext_server_common.go
View file @
0b3118e0
...
...
@@ -153,7 +153,7 @@ func newExtServer(server extServer, model string, adapters, projectors []string,
return
server
,
nil
}
func
predict
(
llm
extServer
,
opts
api
.
Options
,
ctx
context
.
Context
,
predict
PredictOpts
,
fn
func
(
PredictResult
))
error
{
func
predict
(
ctx
context
.
Context
,
llm
extServer
,
predict
PredictOpts
,
fn
func
(
PredictResult
))
error
{
resp
:=
newExtServerResp
(
128
)
defer
freeExtServerResp
(
resp
)
var
imageData
[]
ImageData
...
...
@@ -167,23 +167,23 @@ func predict(llm extServer, opts api.Options, ctx context.Context, predict Predi
request
:=
map
[
string
]
any
{
"prompt"
:
predict
.
Prompt
,
"stream"
:
true
,
"n_predict"
:
opt
s
.
NumPredict
,
"n_keep"
:
opt
s
.
NumKeep
,
"temperature"
:
opt
s
.
Temperature
,
"top_k"
:
opt
s
.
TopK
,
"top_p"
:
opt
s
.
TopP
,
"tfs_z"
:
opt
s
.
TFSZ
,
"typical_p"
:
opt
s
.
TypicalP
,
"repeat_last_n"
:
opt
s
.
RepeatLastN
,
"repeat_penalty"
:
opt
s
.
RepeatPenalty
,
"presence_penalty"
:
opt
s
.
PresencePenalty
,
"frequency_penalty"
:
opt
s
.
FrequencyPenalty
,
"mirostat"
:
opt
s
.
Mirostat
,
"mirostat_tau"
:
opt
s
.
MirostatTau
,
"mirostat_eta"
:
opt
s
.
MirostatEta
,
"penalize_nl"
:
opt
s
.
PenalizeNewline
,
"seed"
:
opt
s
.
Seed
,
"stop"
:
opt
s
.
Stop
,
"n_predict"
:
predict
.
Option
s
.
NumPredict
,
"n_keep"
:
predict
.
Option
s
.
NumKeep
,
"temperature"
:
predict
.
Option
s
.
Temperature
,
"top_k"
:
predict
.
Option
s
.
TopK
,
"top_p"
:
predict
.
Option
s
.
TopP
,
"tfs_z"
:
predict
.
Option
s
.
TFSZ
,
"typical_p"
:
predict
.
Option
s
.
TypicalP
,
"repeat_last_n"
:
predict
.
Option
s
.
RepeatLastN
,
"repeat_penalty"
:
predict
.
Option
s
.
RepeatPenalty
,
"presence_penalty"
:
predict
.
Option
s
.
PresencePenalty
,
"frequency_penalty"
:
predict
.
Option
s
.
FrequencyPenalty
,
"mirostat"
:
predict
.
Option
s
.
Mirostat
,
"mirostat_tau"
:
predict
.
Option
s
.
MirostatTau
,
"mirostat_eta"
:
predict
.
Option
s
.
MirostatEta
,
"penalize_nl"
:
predict
.
Option
s
.
PenalizeNewline
,
"seed"
:
predict
.
Option
s
.
Seed
,
"stop"
:
predict
.
Option
s
.
Stop
,
"image_data"
:
imageData
,
"cache_prompt"
:
true
,
}
...
...
llm/ext_server_default.go
View file @
0b3118e0
...
...
@@ -60,7 +60,7 @@ func newDefaultExtServer(model string, adapters, projectors []string, numLayers
}
func
(
llm
*
llamaExtServer
)
Predict
(
ctx
context
.
Context
,
pred
PredictOpts
,
fn
func
(
PredictResult
))
error
{
return
predict
(
llm
,
llm
.
Options
,
ctx
,
pred
,
fn
)
return
predict
(
ctx
,
llm
,
pred
,
fn
)
}
func
(
llm
*
llamaExtServer
)
Encode
(
ctx
context
.
Context
,
prompt
string
)
([]
int
,
error
)
{
...
...
llm/llama.go
View file @
0b3118e0
...
...
@@ -166,9 +166,10 @@ const maxRetries = 3
// retryDelay is the pause between successive prediction retry attempts.
const retryDelay = time.Second
type
PredictOpts
struct
{
Prompt
string
Format
string
Images
[]
api
.
ImageData
Prompt
string
Format
string
Images
[]
api
.
ImageData
Options
api
.
Options
}
type
PredictResult
struct
{
...
...
llm/shim_ext_server.go
View file @
0b3118e0
...
...
@@ -92,7 +92,7 @@ func newDynamicShimExtServer(library, model string, adapters, projectors []strin
}
func
(
llm
*
shimExtServer
)
Predict
(
ctx
context
.
Context
,
pred
PredictOpts
,
fn
func
(
PredictResult
))
error
{
return
predict
(
llm
,
llm
.
options
,
ctx
,
pred
,
fn
)
return
predict
(
ctx
,
llm
,
pred
,
fn
)
}
func
(
llm
*
shimExtServer
)
Encode
(
ctx
context
.
Context
,
prompt
string
)
([]
int
,
error
)
{
...
...
server/routes.go
View file @
0b3118e0
...
...
@@ -64,24 +64,9 @@ var loaded struct {
var
defaultSessionDuration
=
5
*
time
.
Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func
load
(
c
*
gin
.
Context
,
modelName
string
,
reqOpts
map
[
string
]
interface
{},
sessionDuration
time
.
Duration
)
(
*
Model
,
error
)
{
model
,
err
:=
GetModel
(
modelName
)
if
err
!=
nil
{
return
nil
,
err
}
func
load
(
c
*
gin
.
Context
,
model
*
Model
,
opts
api
.
Options
,
sessionDuration
time
.
Duration
)
error
{
workDir
:=
c
.
GetString
(
"workDir"
)
opts
:=
api
.
DefaultOptions
()
if
err
:=
opts
.
FromMap
(
model
.
Options
);
err
!=
nil
{
log
.
Printf
(
"could not load model options: %v"
,
err
)
return
nil
,
err
}
if
err
:=
opts
.
FromMap
(
reqOpts
);
err
!=
nil
{
return
nil
,
err
}
needLoad
:=
loaded
.
runner
==
nil
||
// is there a model loaded?
loaded
.
ModelPath
!=
model
.
ModelPath
||
// has the base model changed?
!
reflect
.
DeepEqual
(
loaded
.
AdapterPaths
,
model
.
AdapterPaths
)
||
// have the adapters changed?
...
...
@@ -105,7 +90,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
err
=
fmt
.
Errorf
(
"%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`"
,
err
,
model
.
ShortName
)
}
return
nil
,
err
return
err
}
loaded
.
Model
=
model
...
...
@@ -135,7 +120,20 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
}
loaded
.
expireTimer
.
Reset
(
sessionDuration
)
return
model
,
nil
return
nil
}
func
modelOptions
(
model
*
Model
,
requestOpts
map
[
string
]
interface
{})
(
api
.
Options
,
error
)
{
opts
:=
api
.
DefaultOptions
()
if
err
:=
opts
.
FromMap
(
model
.
Options
);
err
!=
nil
{
return
api
.
Options
{},
err
}
if
err
:=
opts
.
FromMap
(
requestOpts
);
err
!=
nil
{
return
api
.
Options
{},
err
}
return
opts
,
nil
}
func
GenerateHandler
(
c
*
gin
.
Context
)
{
...
...
@@ -168,18 +166,30 @@ func GenerateHandler(c *gin.Context) {
return
}
sessionDuration
:=
defaultSessionDuration
model
,
err
:=
load
(
c
,
req
.
Model
,
req
.
Options
,
sessionDuration
)
model
,
err
:=
GetModel
(
req
.
Model
)
if
err
!=
nil
{
var
pErr
*
fs
.
PathError
switch
{
case
errors
.
As
(
err
,
&
pErr
)
:
if
errors
.
As
(
err
,
&
pErr
)
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found, try pulling it first"
,
req
.
Model
)})
case
errors
.
Is
(
err
,
api
.
ErrInvalidOpts
)
:
return
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
opts
,
err
:=
modelOptions
(
model
,
req
.
Options
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
api
.
ErrInvalidOpts
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
default
:
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
sessionDuration
:=
defaultSessionDuration
if
err
:=
load
(
c
,
model
,
opts
,
sessionDuration
);
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
...
...
@@ -287,9 +297,10 @@ func GenerateHandler(c *gin.Context) {
// Start prediction
predictReq
:=
llm
.
PredictOpts
{
Prompt
:
prompt
,
Format
:
req
.
Format
,
Images
:
req
.
Images
,
Prompt
:
prompt
,
Format
:
req
.
Format
,
Images
:
req
.
Images
,
Options
:
opts
,
}
if
err
:=
loaded
.
runner
.
Predict
(
c
.
Request
.
Context
(),
predictReq
,
fn
);
err
!=
nil
{
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
...
...
@@ -347,18 +358,29 @@ func EmbeddingHandler(c *gin.Context) {
return
}
sessionDuration
:=
defaultSessionDuration
_
,
err
=
load
(
c
,
req
.
Model
,
req
.
Options
,
sessionDuration
)
model
,
err
:=
GetModel
(
req
.
Model
)
if
err
!=
nil
{
var
pErr
*
fs
.
PathError
switch
{
case
errors
.
As
(
err
,
&
pErr
)
:
if
errors
.
As
(
err
,
&
pErr
)
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found, try pulling it first"
,
req
.
Model
)})
case
errors
.
Is
(
err
,
api
.
ErrInvalidOpts
)
:
return
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
opts
,
err
:=
modelOptions
(
model
,
req
.
Options
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
api
.
ErrInvalidOpts
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
default
:
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
sessionDuration
:=
defaultSessionDuration
if
err
:=
load
(
c
,
model
,
opts
,
sessionDuration
);
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
...
...
@@ -991,18 +1013,29 @@ func ChatHandler(c *gin.Context) {
return
}
sessionDuration
:=
defaultSessionDuration
model
,
err
:=
load
(
c
,
req
.
Model
,
req
.
Options
,
sessionDuration
)
model
,
err
:=
GetModel
(
req
.
Model
)
if
err
!=
nil
{
var
pErr
*
fs
.
PathError
switch
{
case
errors
.
As
(
err
,
&
pErr
)
:
if
errors
.
As
(
err
,
&
pErr
)
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found, try pulling it first"
,
req
.
Model
)})
case
errors
.
Is
(
err
,
api
.
ErrInvalidOpts
)
:
return
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
opts
,
err
:=
modelOptions
(
model
,
req
.
Options
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
api
.
ErrInvalidOpts
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
default
:
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
sessionDuration
:=
defaultSessionDuration
if
err
:=
load
(
c
,
model
,
opts
,
sessionDuration
);
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
...
...
@@ -1053,9 +1086,10 @@ func ChatHandler(c *gin.Context) {
// Start prediction
predictReq
:=
llm
.
PredictOpts
{
Prompt
:
prompt
,
Format
:
req
.
Format
,
Images
:
images
,
Prompt
:
prompt
,
Format
:
req
.
Format
,
Images
:
images
,
Options
:
opts
,
}
if
err
:=
loaded
.
runner
.
Predict
(
c
.
Request
.
Context
(),
predictReq
,
fn
);
err
!=
nil
{
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment