Commit 955f2a4e authored by Daniel Hiltgen's avatar Daniel Hiltgen
Browse files

Only set default keep_alive on initial model load

This change fixes the handling of keep_alive so that if client
request omits the setting, we only set this on initial load.  Once
the model is loaded, if new requests leave this unset, we'll keep
whatever keep_alive was there.
parent ccd77858
...@@ -4,12 +4,14 @@ import ( ...@@ -4,12 +4,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
) )
type OllamaHost struct { type OllamaHost struct {
...@@ -34,7 +36,7 @@ var ( ...@@ -34,7 +36,7 @@ var (
// Set via OLLAMA_HOST in the environment // Set via OLLAMA_HOST in the environment
Host *OllamaHost Host *OllamaHost
// Set via OLLAMA_KEEP_ALIVE in the environment // Set via OLLAMA_KEEP_ALIVE in the environment
KeepAlive string KeepAlive time.Duration
// Set via OLLAMA_LLM_LIBRARY in the environment // Set via OLLAMA_LLM_LIBRARY in the environment
LLMLibrary string LLMLibrary string
// Set via OLLAMA_MAX_LOADED_MODELS in the environment // Set via OLLAMA_MAX_LOADED_MODELS in the environment
...@@ -132,6 +134,7 @@ func init() { ...@@ -132,6 +134,7 @@ func init() {
NumParallel = 0 // Autoselect NumParallel = 0 // Autoselect
MaxRunners = 0 // Autoselect MaxRunners = 0 // Autoselect
MaxQueuedRequests = 512 MaxQueuedRequests = 512
KeepAlive = 5 * time.Minute
LoadConfig() LoadConfig()
} }
...@@ -266,7 +269,10 @@ func LoadConfig() { ...@@ -266,7 +269,10 @@ func LoadConfig() {
} }
} }
KeepAlive = clean("OLLAMA_KEEP_ALIVE") ka := clean("OLLAMA_KEEP_ALIVE")
if ka != "" {
loadKeepAlive(ka)
}
var err error var err error
ModelsDir, err = getModelsDir() ModelsDir, err = getModelsDir()
...@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) { ...@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
Port: port, Port: port,
}, nil }, nil
} }
func loadKeepAlive(ka string) {
v, err := strconv.Atoi(ka)
if err != nil {
d, err := time.ParseDuration(ka)
if err == nil {
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
} else {
d := time.Duration(v) * time.Second
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
}
...@@ -2,8 +2,10 @@ package envconfig ...@@ -2,8 +2,10 @@ package envconfig
import ( import (
"fmt" "fmt"
"math"
"net" "net"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) { ...@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
t.Setenv("OLLAMA_FLASH_ATTENTION", "1") t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig() LoadConfig()
require.True(t, FlashAttention) require.True(t, FlashAttention)
t.Setenv("OLLAMA_KEEP_ALIVE", "")
LoadConfig()
require.Equal(t, 5*time.Minute, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "3")
LoadConfig()
require.Equal(t, 3*time.Second, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
LoadConfig()
require.Equal(t, 1*time.Hour, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
} }
func TestClientFromEnvironment(t *testing.T) { func TestClientFromEnvironment(t *testing.T) {
......
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
...@@ -17,7 +16,6 @@ import ( ...@@ -17,7 +16,6 @@ import (
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"slices" "slices"
"strconv"
"strings" "strings"
"syscall" "syscall"
"time" "time"
...@@ -56,8 +54,6 @@ func init() { ...@@ -56,8 +54,6 @@ func init() {
gin.SetMode(mode) gin.SetMode(mode)
} }
var defaultSessionDuration = 5 * time.Minute
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions() opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil { if err := opts.FromMap(model.Options); err != nil {
...@@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
var sessionDuration time.Duration rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef var runner *runnerRef
select { select {
case runner = <-rCh: case runner = <-rCh:
...@@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func getDefaultSessionDuration() time.Duration {
if envconfig.KeepAlive != "" {
v, err := strconv.Atoi(envconfig.KeepAlive)
if err != nil {
d, err := time.ParseDuration(envconfig.KeepAlive)
if err != nil {
return defaultSessionDuration
}
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
d := time.Duration(v) * time.Second
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
return defaultSessionDuration
}
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
...@@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { ...@@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
var sessionDuration time.Duration rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef var runner *runnerRef
select { select {
case runner = <-rCh: case runner = <-rCh:
...@@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
var sessionDuration time.Duration rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef var runner *runnerRef
select { select {
case runner = <-rCh: case runner = <-rCh:
......
...@@ -24,7 +24,7 @@ type LlmRequest struct { ...@@ -24,7 +24,7 @@ type LlmRequest struct {
model *Model model *Model
opts api.Options opts api.Options
origNumCtx int // Track the initial ctx request origNumCtx int // Track the initial ctx request
sessionDuration time.Duration sessionDuration *api.Duration
successCh chan *runnerRef successCh chan *runnerRef
errCh chan error errCh chan error
schedAttempts uint schedAttempts uint
...@@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler { ...@@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
} }
// context must be canceled to decrement ref count and release the runner // context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) { func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 { if opts.NumCtx < 4 {
opts.NumCtx = 4 opts.NumCtx = 4
} }
...@@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm ...@@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
runner.expireTimer.Stop() runner.expireTimer.Stop()
runner.expireTimer = nil runner.expireTimer = nil
} }
runner.sessionDuration = pending.sessionDuration if pending.sessionDuration != nil {
runner.sessionDuration = pending.sessionDuration.Duration
}
pending.successCh <- runner pending.successCh <- runner
go func() { go func() {
<-pending.ctx.Done() <-pending.ctx.Done()
...@@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, ...@@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
if numParallel < 1 { if numParallel < 1 {
numParallel = 1 numParallel = 1
} }
sessionDuration := envconfig.KeepAlive
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel) llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil { if err != nil {
// some older models are not compatible with newer versions of llama.cpp // some older models are not compatible with newer versions of llama.cpp
...@@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, ...@@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
modelPath: req.model.ModelPath, modelPath: req.model.ModelPath,
llama: llama, llama: llama,
Options: &req.opts, Options: &req.opts,
sessionDuration: req.sessionDuration, sessionDuration: sessionDuration,
gpus: gpus, gpus: gpus,
estimatedVRAM: llama.EstimatedVRAM(), estimatedVRAM: llama.EstimatedVRAM(),
estimatedTotal: llama.EstimatedTotal(), estimatedTotal: llama.EstimatedTotal(),
......
...@@ -44,7 +44,7 @@ func TestLoad(t *testing.T) { ...@@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
sessionDuration: 2, sessionDuration: &api.Duration{Duration: 2 * time.Second},
} }
// Fail to load model first // Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
...@@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV ...@@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
ctx: scenario.ctx, ctx: scenario.ctx,
model: model, model: model,
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
sessionDuration: 5 * time.Millisecond, sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
} }
...@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) { ...@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
// Same model, same request // Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10) scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 5 * time.Millisecond scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1b := newScenario(t, ctx, "ollama-model-1", 11) scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = 0 scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
// simple reload of same model // simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20) scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = 5 * time.Millisecond scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
// Multiple loaded models // Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte) scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
...@@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) { ...@@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) {
defer done() defer done()
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = 0 scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10) scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = 0 scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10) scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = 0 scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
envconfig.MaxQueuedRequests = 1 envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx) s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList { s.getGpuFn = func() gpu.GpuInfoList {
...@@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) { ...@@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
time.Sleep(scenario1a.req.sessionDuration) time.Sleep(scenario1a.req.sessionDuration.Duration)
scenario1a.ctxDone() scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1) require.LessOrEqual(t, len(s.finishedReqCh), 1)
...@@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) { ...@@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
ctx: ctx, ctx: ctx,
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
sessionDuration: 2, sessionDuration: &api.Duration{Duration: 2},
} }
finished := make(chan *LlmRequest) finished := make(chan *LlmRequest)
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
...@@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) { ...@@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
dctx, done2 := context.WithCancel(ctx) dctx, done2 := context.WithCancel(ctx)
done2() done2()
scenario1a := newScenario(t, dctx, "ollama-model-1", 10) scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 0 scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
s := InitScheduler(ctx) s := InitScheduler(ctx)
slog.Info("scenario1a") slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req s.pendingReqCh <- scenario1a.req
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment