OpenDAS / ollama · Commit 8072e205 (unverified)

Authored Jul 03, 2024 by Daniel Hiltgen; committed by GitHub on Jul 03, 2024
Parents: ccd77858, 955f2a4e

Merge pull request #5447 from dhiltgen/fix_keepalive

Only set default keep_alive on initial model load

Showing 5 changed files with 70 additions and 71 deletions (+70, -71)
  envconfig/config.go       +29  -2
  envconfig/config_test.go  +17  -0
  server/routes.go           +3 -54
  server/sched.go           +10  -4
  server/sched_test.go      +11 -11
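In summary: OLLAMA_KEEP_ALIVE is now parsed once, at config load, into a typed time.Duration (envconfig.KeepAlive), and the HTTP handlers stop re-deriving a default on every request. Below is a minimal standalone sketch of the parsing semantics that the new loadKeepAlive (first file below) implements and the new tests exercise; the helper name parseKeepAlive and the main harness are illustrative, not part of the commit:

    package main

    import (
        "fmt"
        "math"
        "strconv"
        "time"
    )

    // parseKeepAlive mirrors envconfig.loadKeepAlive: a bare integer is
    // taken as seconds, anything else is parsed as a Go duration string
    // ("10m", "1h"), and a negative value becomes time.Duration(math.MaxInt64),
    // i.e. "keep the model loaded forever".
    func parseKeepAlive(ka string) time.Duration {
        keepAlive := 5 * time.Minute // the init() default when unset or unparsable
        if v, err := strconv.Atoi(ka); err == nil {
            keepAlive = time.Duration(v) * time.Second
        } else if d, err := time.ParseDuration(ka); err == nil {
            keepAlive = d
        }
        if keepAlive < 0 {
            return time.Duration(math.MaxInt64)
        }
        return keepAlive
    }

    func main() {
        for _, ka := range []string{"3", "1h", "-1", "-1s"} {
            fmt.Printf("OLLAMA_KEEP_ALIVE=%q -> %v\n", ka, parseKeepAlive(ka))
        }
    }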
envconfig/config.go

@@ -4,12 +4,14 @@ import (
     "errors"
     "fmt"
     "log/slog"
+    "math"
     "net"
     "os"
     "path/filepath"
     "runtime"
     "strconv"
     "strings"
+    "time"
 )

 type OllamaHost struct {
@@ -34,7 +36,7 @@ var (
     // Set via OLLAMA_HOST in the environment
     Host *OllamaHost
     // Set via OLLAMA_KEEP_ALIVE in the environment
-    KeepAlive string
+    KeepAlive time.Duration
     // Set via OLLAMA_LLM_LIBRARY in the environment
     LLMLibrary string
     // Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -132,6 +134,7 @@ func init() {
     NumParallel = 0 // Autoselect
     MaxRunners = 0  // Autoselect
     MaxQueuedRequests = 512
+    KeepAlive = 5 * time.Minute

     LoadConfig()
 }
@@ -266,7 +269,10 @@ func LoadConfig() {
         }
     }

-    KeepAlive = clean("OLLAMA_KEEP_ALIVE")
+    ka := clean("OLLAMA_KEEP_ALIVE")
+    if ka != "" {
+        loadKeepAlive(ka)
+    }

     var err error
     ModelsDir, err = getModelsDir()
@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
         Port:   port,
     }, nil
 }
+
+func loadKeepAlive(ka string) {
+    v, err := strconv.Atoi(ka)
+    if err != nil {
+        d, err := time.ParseDuration(ka)
+        if err == nil {
+            if d < 0 {
+                KeepAlive = time.Duration(math.MaxInt64)
+            } else {
+                KeepAlive = d
+            }
+        }
+    } else {
+        d := time.Duration(v) * time.Second
+        if d < 0 {
+            KeepAlive = time.Duration(math.MaxInt64)
+        } else {
+            KeepAlive = d
+        }
+    }
+}
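Note the sentinel: any negative keep-alive is stored as time.Duration(math.MaxInt64), which the expiry timer effectively never reaches. This preserves the established "-1 means keep the model loaded indefinitely" convention from the parsing logic that this commit removes from server/routes.go below.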
envconfig/config_test.go

@@ -2,8 +2,10 @@ package envconfig
 import (
     "fmt"
+    "math"
     "net"
     "testing"
+    "time"

     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
     t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
     LoadConfig()
     require.True(t, FlashAttention)
+    t.Setenv("OLLAMA_KEEP_ALIVE", "")
+    LoadConfig()
+    require.Equal(t, 5*time.Minute, KeepAlive)
+    t.Setenv("OLLAMA_KEEP_ALIVE", "3")
+    LoadConfig()
+    require.Equal(t, 3*time.Second, KeepAlive)
+    t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
+    LoadConfig()
+    require.Equal(t, 1*time.Hour, KeepAlive)
+    t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
+    LoadConfig()
+    require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
+    t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
+    LoadConfig()
+    require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
 }

 func TestClientFromEnvironment(t *testing.T) {
server/routes.go

@@ -9,7 +9,6 @@ import (
     "io"
     "io/fs"
     "log/slog"
-    "math"
     "net"
     "net/http"
     "net/netip"
@@ -17,7 +16,6 @@ import (
     "os/signal"
     "path/filepath"
     "slices"
-    "strconv"
     "strings"
     "syscall"
     "time"
@@ -56,8 +54,6 @@ func init() {
     gin.SetMode(mode)
 }

-var defaultSessionDuration = 5 * time.Minute
-
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
     opts := api.DefaultOptions()
     if err := opts.FromMap(model.Options); err != nil {
@@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
         return
     }

-    var sessionDuration time.Duration
-    if req.KeepAlive == nil {
-        sessionDuration = getDefaultSessionDuration()
-    } else {
-        sessionDuration = req.KeepAlive.Duration
-    }
-
-    rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+    rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
     var runner *runnerRef
     select {
     case runner = <-rCh:
@@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
     streamResponse(c, ch)
 }

-func getDefaultSessionDuration() time.Duration {
-    if envconfig.KeepAlive != "" {
-        v, err := strconv.Atoi(envconfig.KeepAlive)
-        if err != nil {
-            d, err := time.ParseDuration(envconfig.KeepAlive)
-            if err != nil {
-                return defaultSessionDuration
-            }
-
-            if d < 0 {
-                return time.Duration(math.MaxInt64)
-            }
-
-            return d
-        }
-
-        d := time.Duration(v) * time.Second
-        if d < 0 {
-            return time.Duration(math.MaxInt64)
-        }
-        return d
-    }
-
-    return defaultSessionDuration
-}
-
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
     var req api.EmbeddingRequest
     err := c.ShouldBindJSON(&req)
@@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
         return
     }

-    var sessionDuration time.Duration
-    if req.KeepAlive == nil {
-        sessionDuration = getDefaultSessionDuration()
-    } else {
-        sessionDuration = req.KeepAlive.Duration
-    }
-
-    rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+    rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
     var runner *runnerRef
     select {
     case runner = <-rCh:
@@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
         return
     }

-    var sessionDuration time.Duration
-    if req.KeepAlive == nil {
-        sessionDuration = getDefaultSessionDuration()
-    } else {
-        sessionDuration = req.KeepAlive.Duration
-    }
-
-    rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+    rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
     var runner *runnerRef
     select {
     case runner = <-rCh:
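The practical effect for clients: a keep_alive in the request body still overrides everything, but omitting it no longer resets a loaded runner to the default. A hedged client-side sketch, assuming the Go client in this repo's api package at the upstream module path (the model name is a placeholder):

    package main

    import (
        "context"
        "log"
        "time"

        "github.com/ollama/ollama/api" // assumption: upstream module path for this repo's api package
    )

    func main() {
        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }

        req := &api.GenerateRequest{
            Model:  "llama3", // placeholder model name
            Prompt: "Why is the sky blue?",
            // Explicit override: keep the runner loaded for 10 minutes.
            // With KeepAlive left nil, the server default (OLLAMA_KEEP_ALIVE,
            // else 5m) applies on initial load only, and an already-loaded
            // runner's expiry timer is left untouched.
            KeepAlive: &api.Duration{Duration: 10 * time.Minute},
        }

        err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
            return nil // streaming callback; tokens ignored in this sketch
        })
        if err != nil {
            log.Fatal(err)
        }
    }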
server/sched.go

@@ -24,7 +24,7 @@ type LlmRequest struct {
     model           *Model
     opts            api.Options
     origNumCtx      int // Track the initial ctx request
-    sessionDuration time.Duration
+    sessionDuration *api.Duration
     successCh       chan *runnerRef
     errCh           chan error
     schedAttempts   uint
@@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
 }

 // context must be canceled to decrement ref count and release the runner
-func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
     if opts.NumCtx < 4 {
         opts.NumCtx = 4
     }
@@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
     runner.expireTimer.Stop()
     runner.expireTimer = nil
 }
-    runner.sessionDuration = pending.sessionDuration
+    if pending.sessionDuration != nil {
+        runner.sessionDuration = pending.sessionDuration.Duration
+    }
     pending.successCh <- runner
     go func() {
         <-pending.ctx.Done()
@@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
     if numParallel < 1 {
         numParallel = 1
     }
+    sessionDuration := envconfig.KeepAlive
+    if req.sessionDuration != nil {
+        sessionDuration = req.sessionDuration.Duration
+    }
     llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
     if err != nil {
         // some older models are not compatible with newer versions of llama.cpp
@@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
         modelPath:       req.model.ModelPath,
         llama:           llama,
         Options:         &req.opts,
-        sessionDuration: req.sessionDuration,
+        sessionDuration: sessionDuration,
         gpus:            gpus,
         estimatedVRAM:   llama.EstimatedVRAM(),
         estimatedTotal:  llama.EstimatedTotal(),
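The switch from time.Duration to *api.Duration is what lets the scheduler distinguish "not specified" from "explicitly zero". A minimal sketch of that pattern; the local Duration type stands in for api.Duration, and effectiveKeepAlive is illustrative rather than a function in the diff:

    package main

    import (
        "fmt"
        "time"
    )

    // Duration stands in for this repo's api.Duration (assumption: it wraps
    // a time.Duration, as the diff's &api.Duration{Duration: ...} literals suggest).
    type Duration struct {
        Duration time.Duration
    }

    // effectiveKeepAlive mirrors the scheduler's new fallback: nil means the
    // client sent no keep_alive (use the env-derived default on initial load,
    // and leave an already-loaded runner's timer alone), while a non-nil
    // value, including zero, is honored as-is.
    func effectiveKeepAlive(req *Duration, fallback time.Duration) time.Duration {
        if req != nil {
            return req.Duration
        }
        return fallback
    }

    func main() {
        def := 5 * time.Minute
        fmt.Println(effectiveKeepAlive(nil, def))          // 5m0s: unset -> default
        fmt.Println(effectiveKeepAlive(&Duration{0}, def)) // 0s: explicit zero -> unload now
    }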
server/sched_test.go

@@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
     opts:            api.DefaultOptions(),
     successCh:       make(chan *runnerRef, 1),
     errCh:           make(chan error, 1),
-    sessionDuration: 2,
+    sessionDuration: &api.Duration{Duration: 2 * time.Second},
 }
 // Fail to load model first
 s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
     ctx:             scenario.ctx,
     model:           model,
     opts:            api.DefaultOptions(),
-    sessionDuration: 5 * time.Millisecond,
+    sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
     successCh:       make(chan *runnerRef, 1),
     errCh:           make(chan error, 1),
 }
@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
 // Same model, same request
 scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-scenario1a.req.sessionDuration = 5 * time.Millisecond
+scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
 scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
 scenario1b.req.model = scenario1a.req.model
 scenario1b.ggml = scenario1a.ggml
-scenario1b.req.sessionDuration = 0
+scenario1b.req.sessionDuration = &api.Duration{Duration: 0}

 // simple reload of same model
 scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
 tmpModel := *scenario1a.req.model
 scenario2a.req.model = &tmpModel
 scenario2a.ggml = scenario1a.ggml
-scenario2a.req.sessionDuration = 5 * time.Millisecond
+scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}

 // Multiple loaded models
 scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) {
 defer done()
 scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
-scenario1a.req.sessionDuration = 0
+scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
 scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
-scenario1b.req.sessionDuration = 0
+scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
 scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
-scenario1c.req.sessionDuration = 0
+scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
 envconfig.MaxQueuedRequests = 1
 s := InitScheduler(ctx)
 s.getGpuFn = func() gpu.GpuInfoList {
@@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
 case <-ctx.Done():
     t.Fatal("timeout")
 }
-time.Sleep(scenario1a.req.sessionDuration)
+time.Sleep(scenario1a.req.sessionDuration.Duration)
 scenario1a.ctxDone()
 time.Sleep(20 * time.Millisecond)
 require.LessOrEqual(t, len(s.finishedReqCh), 1)
@@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
     ctx:             ctx,
     opts:            api.DefaultOptions(),
     successCh:       make(chan *runnerRef, 1),
-    sessionDuration: 2,
+    sessionDuration: &api.Duration{Duration: 2},
 }
 finished := make(chan *LlmRequest)
 llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
 dctx, done2 := context.WithCancel(ctx)
 done2()
 scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
-scenario1a.req.sessionDuration = 0
+scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
 s := InitScheduler(ctx)
 slog.Info("scenario1a")
 s.pendingReqCh <- scenario1a.req