Commit 3c75113e authored by Daniel Hiltgen's avatar Daniel Hiltgen
Browse files

Prevent loading models larger than total memory

Users may not realize the siny new model they're trying to load
fits on their disk, but can't load into system+GPU memory.  Today
we crash, but with this fix, we'll give them a better error message
before even trying to load it.
parent 3b5a4a77
...@@ -139,6 +139,11 @@ func (s *Scheduler) processPending(ctx context.Context) { ...@@ -139,6 +139,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
} }
for { for {
cpus := s.getCpuFn()
var systemMem gpu.GpuInfo
if len(cpus) > 0 {
systemMem = cpus[0]
}
var runnerToExpire *runnerRef var runnerToExpire *runnerRef
s.loadedMu.Lock() s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath] runner := s.loaded[pending.model.ModelPath]
...@@ -192,6 +197,27 @@ func (s *Scheduler) processPending(ctx context.Context) { ...@@ -192,6 +197,27 @@ func (s *Scheduler) processPending(ctx context.Context) {
break break
} }
// Block attempting to load a model larger than system memory + GPU memory
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
maxSize := systemMem.FreeMemory
for _, gpu := range gpus {
if gpu.Library == "cpu" {
continue
}
if loadedCount == 0 {
// If no other models are loaded, set the limit based on what's available
maxSize += gpu.FreeMemory
} else {
// Other models could be unloaded, favor total memory for limit
maxSize += gpu.TotalMemory
}
}
if estimate.TotalSize > maxSize {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
break
}
// Evaluate if the model will fit in the available system memory, or if we should unload a model first // Evaluate if the model will fit in the available system memory, or if we should unload a model first
if len(gpus) == 1 && gpus[0].Library == "cpu" { if len(gpus) == 1 && gpus[0].Library == "cpu" {
// simplifying assumption of defaultParallel when in CPU mode // simplifying assumption of defaultParallel when in CPU mode
......
...@@ -199,6 +199,8 @@ func TestRequests(t *testing.T) { ...@@ -199,6 +199,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1a.req.errCh) require.Empty(t, scenario1a.req.errCh)
case err := <-scenario1a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
...@@ -212,6 +214,8 @@ func TestRequests(t *testing.T) { ...@@ -212,6 +214,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1b.req.errCh) require.Empty(t, scenario1b.req.errCh)
case err := <-scenario1b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
...@@ -230,6 +234,8 @@ func TestRequests(t *testing.T) { ...@@ -230,6 +234,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario2a.srv) require.Equal(t, resp.llama, scenario2a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario2a.req.errCh) require.Empty(t, scenario2a.req.errCh)
case err := <-scenario2a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
...@@ -246,6 +252,8 @@ func TestRequests(t *testing.T) { ...@@ -246,6 +252,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario3a.srv) require.Equal(t, resp.llama, scenario3a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3a.req.errCh) require.Empty(t, scenario3a.req.errCh)
case err := <-scenario3a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
...@@ -262,6 +270,8 @@ func TestRequests(t *testing.T) { ...@@ -262,6 +270,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario3b.srv) require.Equal(t, resp.llama, scenario3b.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3b.req.errCh) require.Empty(t, scenario3b.req.errCh)
case err := <-scenario3b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
...@@ -278,6 +288,8 @@ func TestRequests(t *testing.T) { ...@@ -278,6 +288,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario3c.srv) require.Equal(t, resp.llama, scenario3c.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3c.req.errCh) require.Empty(t, scenario3c.req.errCh)
case err := <-scenario3c.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment