OpenDAS / ollama · Commits

Commit c863c6a9 (Unverified)
Authored Apr 02, 2024 by Daniel Hiltgen; committed by GitHub, Apr 02, 2024
Merge pull request #3218 from dhiltgen/subprocess
Switch back to subprocessing for llama.cpp
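Editor's note: "subprocessing" here means the llama.cpp runner no longer lives in-process behind the llm.LLM interface; loaded now holds a *llm.LlamaServer handle to a separately spawned llama.cpp server, including a Ping health check so a dead subprocess can be detected and reloaded. Reconstructed from the call sites in this diff alone (signatures are inferred, not the llm package's authoritative definitions), the new surface looks roughly like:

// Illustrative sketch only: the *llm.LlamaServer method set as implied
// by the call sites below; parameter and result types are inferred.
type llamaServer interface {
	Ping(ctx context.Context) error // health-check the running subprocess
	Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error
	Tokenize(ctx context.Context, prompt string) ([]int, error)   // replaces llm.LLM's Encode
	Detokenize(ctx context.Context, tokens []int) (string, error) // replaces llm.LLM's Decode
	Embedding(ctx context.Context, prompt string) ([]float64, error)
	Close() error // return value, if any, is ignored at every call site
}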
Parents: 3b6a9154, 1f11b525

Showing 2 changed files with 43 additions and 66 deletions:
server/routes.go        +42 -40
server/routes_test.go   +1  -26
--- a/server/routes.go
+++ b/server/routes.go
@@ -56,12 +56,13 @@ func init() {
 var loaded struct {
 	mu sync.Mutex
 
-	runner llm.LLM
+	llama *llm.LlamaServer
 
 	expireAt    time.Time
 	expireTimer *time.Timer
 
-	*Model
+	model      string
+	adapters   []string
+	projectors []string
 	*api.Options
 }
@@ -69,21 +70,28 @@ var defaultSessionDuration = 5 * time.Minute
 
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
 func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.Duration) error {
-	needLoad := loaded.runner == nil || // is there a model loaded?
-		loaded.ModelPath != model.ModelPath || // has the base model changed?
-		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
-		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
+	ctx, cancel := context.WithTimeout(c, 10*time.Second)
+	defer cancel()
+
+	needLoad := loaded.llama == nil || // is there a model loaded?
+		loaded.model != model.ModelPath || // has the base model changed?
+		!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
+		!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the projectors changed?
+		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
+		loaded.llama.Ping(ctx) != nil
 
 	if needLoad {
-		if loaded.runner != nil {
+		if loaded.llama != nil {
 			slog.Info("changing loaded model")
-			loaded.runner.Close()
-			loaded.runner = nil
-			loaded.Model = nil
+			loaded.llama.Close()
+			loaded.llama = nil
+			loaded.model = ""
+			loaded.adapters = nil
+			loaded.projectors = nil
 			loaded.Options = nil
 		}
 
-		llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
+		llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
 		if err != nil {
 			// some older models are not compatible with newer versions of llama.cpp
 			// show a generalized compatibility error until there is a better way to
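Editor's note: the needLoad expression relies on Go's short-circuit evaluation of ||. loaded.llama.Ping(ctx) is the last clause, so it is only reached (and loaded.llama only dereferenced) once loaded.llama == nil has already evaluated to false; in effect, a crashed or unresponsive subprocess now triggers a reload the same way a model change does. A minimal standalone illustration of that guarantee (the server/ping names are hypothetical):

package main

import "fmt"

type server struct{}

func (s *server) ping() error { return nil }

func main() {
	var s *server // nil, like loaded.llama before the first load
	// The ping clause is never evaluated while s == nil,
	// so there is no nil-pointer dereference.
	needLoad := s == nil || s.ping() != nil
	fmt.Println(needLoad) // prints: true
}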
@@ -95,28 +103,26 @@ func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.
 			return err
 		}
 
-		loaded.Model = model
-		loaded.runner = llmRunner
+		loaded.model = model.ModelPath
+		loaded.adapters = model.AdapterPaths
+		loaded.projectors = model.ProjectorPaths
+		loaded.llama = llama
 		loaded.Options = opts
 	}
 
 	loaded.expireAt = time.Now().Add(sessionDuration)
 
 	if loaded.expireTimer == nil {
 		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
 			loaded.mu.Lock()
 			defer loaded.mu.Unlock()
 
-			if time.Now().Before(loaded.expireAt) {
-				return
-			}
-
-			if loaded.runner != nil {
-				loaded.runner.Close()
-			}
-
-			loaded.runner = nil
-			loaded.Model = nil
+			if loaded.llama != nil {
+				loaded.llama.Close()
+			}
+
+			loaded.llama = nil
+			loaded.model = ""
+			loaded.adapters = nil
+			loaded.projectors = nil
 			loaded.Options = nil
 		})
 	}
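Editor's note: the expiry logic above is a create-once, push-forward timer. time.AfterFunc arms the unload callback the first time a model loads, and the handlers further down call loaded.expireTimer.Reset(sessionDuration) on every streamed response, so the model only unloads after a full idle window; the old time.Now().Before(loaded.expireAt) re-check inside the callback is dropped, leaving the reset timer alone to decide. A minimal standalone sketch of the pattern (durations and names are illustrative):

package main

import (
	"fmt"
	"time"
)

func main() {
	idle := 200 * time.Millisecond
	done := make(chan struct{})

	// Armed once; fires only if no activity arrives for one idle window.
	timer := time.AfterFunc(idle, func() {
		fmt.Println("idle window elapsed: unloading")
		close(done)
	})

	// Each piece of activity pushes the deadline forward, as the
	// handlers do with loaded.expireTimer.Reset(sessionDuration).
	for i := 0; i < 3; i++ {
		time.Sleep(100 * time.Millisecond)
		timer.Reset(idle)
	}

	<-done // fires roughly 200ms after the last Reset
}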
@@ -265,7 +271,7 @@ func GenerateHandler(c *gin.Context) {
 
 		sb.Reset()
 		if req.Context != nil {
-			prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
+			prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
@@ -286,9 +292,8 @@ func GenerateHandler(c *gin.Context) {
 	go func() {
 		defer close(ch)
 
-		fn := func(r llm.PredictResult) {
+		fn := func(r llm.CompletionResponse) {
 			// Update model expiration
 			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
 
 			// Build up the full response
@@ -322,7 +327,7 @@ func GenerateHandler(c *gin.Context) {
 		}
 
 		// TODO (jmorganca): encode() should not strip special tokens
-		tokens, err := loaded.runner.Encode(c.Request.Context(), p)
+		tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
 		if err != nil {
 			ch <- gin.H{"error": err.Error()}
 			return
@@ -344,13 +349,13 @@ func GenerateHandler(c *gin.Context) {
 		}
 
 		// Start prediction
-		predictReq := llm.PredictOpts{
+		req := llm.CompletionRequest{
 			Prompt:  prompt,
 			Format:  req.Format,
 			Images:  images,
 			Options: opts,
 		}
-		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+		if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -471,7 +476,7 @@ func EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -1123,8 +1128,8 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		if loaded.runner != nil {
-			loaded.runner.Close()
+		if loaded.llama != nil {
+			loaded.llama.Close()
 		}
 		gpu.Cleanup()
 		os.Exit(0)
@@ -1196,7 +1201,7 @@ func streamResponse(c *gin.Context, ch chan any) {
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
 func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
-		return loaded.runner.Encode(ctx, s)
+		return loaded.llama.Tokenize(ctx, s)
 	}
 
 	prompt, err := ChatPrompt(template, messages, numCtx, encode)
@@ -1326,9 +1331,8 @@ func ChatHandler(c *gin.Context) {
 	go func() {
 		defer close(ch)
 
-		fn := func(r llm.PredictResult) {
+		fn := func(r llm.CompletionResponse) {
 			// Update model expiration
 			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
 
 			resp := api.ChatResponse{
@@ -1352,14 +1356,12 @@ func ChatHandler(c *gin.Context) {
 			ch <- resp
 		}
 
-		// Start prediction
-		predictReq := llm.PredictOpts{
+		if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Format:  req.Format,
 			Images:  images,
 			Options: opts,
-		}
-		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+		}, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -17,7 +17,6 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/version"
 )
@@ -211,7 +210,7 @@ func Test_Routes(t *testing.T) {
 		},
 	}
 
-	s := Server{}
+	s := &Server{}
 	router := s.GenerateRoutes()
 
 	httpSrv := httptest.NewServer(router)
@@ -242,27 +241,3 @@ func Test_Routes(t *testing.T) {
 		}
 	}
 }
-
-type MockLLM struct {
-	encoding []int
-}
-
-func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
-	return nil
-}
-
-func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
-	return llm.encoding, nil
-}
-
-func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
-	return "", nil
-}
-
-func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
-	return []float64{}, nil
-}
-
-func (llm *MockLLM) Close() {
-	// do nothing
-}
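Editor's note: MockLLM can go because loaded now stores a concrete *llm.LlamaServer rather than the llm.LLM interface, so there is no interface left for a mock to satisfy; the route test instead exercises the real handlers over HTTP via httptest, per the s := &Server{} hunk above. A condensed sketch of that setup (the /api/version request is an assumption suggested by the file's version import):

s := &Server{}
router := s.GenerateRoutes()

httpSrv := httptest.NewServer(router)
defer httpSrv.Close()

resp, err := http.Get(httpSrv.URL + "/api/version")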