Commit 34b9db5a authored by Daniel Hiltgen

Request and model concurrency

This change adds support for multiple concurrent requests, as well as
for loading multiple models by spawning multiple runners. The defaults
are currently 1 concurrent request per model and only 1 loaded model at
a time, but these can be adjusted by setting OLLAMA_NUM_PARALLEL and
OLLAMA_MAX_LOADED_MODELS.
parent ee448dea
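
Before the diff itself: a minimal, hedged sketch of how the two knobs could be consumed, for readers who want to experiment. The variable names match the commit message, but the envInt helper and the plumbing around it are illustrative assumptions, not the commit's code (the real OLLAMA_NUM_PARALLEL parsing is visible in the NewLlamaServer hunk below, and OLLAMA_MAX_LOADED_MODELS is consumed by the new scheduler).

package main

import (
	"fmt"
	"os"
	"strconv"
)

// envInt is a hypothetical helper: parse a positive integer from the
// environment, falling back to def when the variable is unset or invalid.
func envInt(name string, def int) int {
	if v := os.Getenv(name); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			return n
		}
	}
	return def
}

func main() {
	numParallel := envInt("OLLAMA_NUM_PARALLEL", 1)    // concurrent requests per loaded model
	maxLoaded := envInt("OLLAMA_MAX_LOADED_MODELS", 1) // models resident at once
	fmt.Println("parallel requests per model:", numParallel)
	fmt.Println("max loaded models:", maxLoaded)
}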
@@ -4,7 +4,6 @@ package integration

 import (
 	"context"
-	"net/http"
 	"testing"
 	"time"
@@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
 			"num_ctx": 128,
 		},
 	}
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
+	GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
 }
@@ -5,7 +5,6 @@ package integration

 import (
 	"context"
 	"encoding/base64"
-	"net/http"
 	"testing"
 	"time"
@@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
 		},
 	}
-	resp := "the ollamas"
+	// Note: sometimes it returns "the ollamas" sometimes "the ollams"
+	resp := "the ollam"
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
+	GenerateTestHelper(ctx, t, req, []string{resp})
 }

 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
...
@@ -4,8 +4,6 @@ package integration

 import (
 	"context"
-	"net/http"
-	"sync"
 	"testing"
 	"time"
@@ -45,25 +43,5 @@ var (
 func TestIntegrationSimpleOrcaMini(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
+	GenerateTestHelper(ctx, t, req[0], resp[0])
 }
-
-// TODO
-// The server always loads a new runner and closes the old one, which forces serial execution
-// At present this test case fails with concurrency problems. Eventually we should try to
-// get true concurrency working with n_parallel support in the backend
-func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
-	var wg sync.WaitGroup
-	wg.Add(len(req))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
-	defer cancel()
-	for i := 0; i < len(req); i++ {
-		go func(i int) {
-			defer wg.Done()
-			GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
-		}(i)
-	}
-	wg.Wait()
-}
-
-// TODO - create a parallel test with 2 different models once we support concurrency
@@ -5,13 +5,14 @@ package integration

 import (
 	"bytes"
 	"context"
-	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log/slog"
 	"math/rand"
 	"net"
 	"net/http"
+	"net/url"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -23,9 +24,13 @@ import (

 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
-	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )

+func Init() {
+	lifecycle.InitLogging()
+}
+
 func FindPort() string {
 	port := 0
 	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@@ -41,7 +46,7 @@ func FindPort() string {
 	return strconv.Itoa(port)
 }

-func GetTestEndpoint() (string, string) {
+func GetTestEndpoint() (*api.Client, string) {
 	defaultPort := "11434"
 	ollamaHost := os.Getenv("OLLAMA_HOST")
@@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
 		port = FindPort()
 	}

-	url := fmt.Sprintf("%s:%s", host, port)
-	slog.Info("server connection", "url", url)
-	return scheme, url
+	slog.Info("server connection", "host", host, "port", port)
+
+	return api.NewClient(
+		&url.URL{
+			Scheme: scheme,
+			Host:   net.JoinHostPort(host, port),
+		},
+		http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
 }

-// TODO make fanicier, grab logs, etc.
 var serverMutex sync.Mutex
 var serverReady bool

-func StartServer(ctx context.Context, ollamaHost string) error {
+func startServer(ctx context.Context, ollamaHost string) error {
 	// Make sure the server has been built
 	CLIName, err := filepath.Abs("../ollama")
 	if err != nil {
@@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
 	return nil
 }

-func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
+func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
 	slog.Info("checking status of model", "model", modelName)
 	showReq := &api.ShowRequest{Name: modelName}
-	requestJSON, err := json.Marshal(showReq)
-	if err != nil {
-		return err
-	}
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}

-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
+	showCtx, cancel := context.WithDeadlineCause(
+		ctx,
+		time.Now().Add(5*time.Second),
+		fmt.Errorf("show for existing model %s took too long", modelName),
+	)
+	defer cancel()
+	_, err := client.Show(showCtx, showReq)
+	var statusError api.StatusError
+	switch {
+	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
+		break
+	case err != nil:
 		return err
-	}
-	defer response.Body.Close()
-	if response.StatusCode == 200 {
+	default:
 		slog.Info("model already present", "model", modelName)
 		return nil
 	}
-	slog.Info("model missing", "status", response.StatusCode)
+	slog.Info("model missing", "model", modelName)
+
+	stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
+	stallTimer := time.NewTimer(stallDuration)
+	fn := func(resp api.ProgressResponse) error {
+		// fmt.Print(".")
+		if !stallTimer.Reset(stallDuration) {
+			return fmt.Errorf("stall was detected, aborting status reporting")
+		}
+		return nil
+	}

 	stream := true
 	pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
-	requestJSON, err = json.Marshal(pullReq)
-	if err != nil {
-		return err
-	}

-	req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}
-	slog.Info("pulling", "model", modelName)
+	var pullError error

-	response, err = client.Do(req.WithContext(ctx))
-	if err != nil {
-		return err
-	}
-	defer response.Body.Close()
+	done := make(chan int)
+	go func() {
+		pullError = client.Pull(ctx, pullReq, fn)
+		done <- 0
+	}()

-	if response.StatusCode != 200 {
-		return fmt.Errorf("failed to pull model") // TODO more details perhaps
+	select {
+	case <-stallTimer.C:
+		return fmt.Errorf("download stalled")
+	case <-done:
+		return pullError
 	}
-
-	slog.Info("model pulled", "model", modelName)
-	return nil
 }

 var serverProcMutex sync.Mutex

-func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
-	// TODO maybe stuff in an init routine?
-	lifecycle.InitLogging()
-
-	requestJSON, err := json.Marshal(genReq)
-	if err != nil {
-		t.Fatalf("Error serializing request: %v", err)
+// Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
+// Starts the server if needed
+func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
+	client, testEndpoint := GetTestEndpoint()
+	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+		serverProcMutex.Lock()
+		fp, err := os.CreateTemp("", "ollama-server-*.log")
+		if err != nil {
+			t.Fatalf("failed to generate log file: %s", err)
+		}
+		lifecycle.ServerLogFile = fp.Name()
+		fp.Close()
+		require.NoError(t, startServer(ctx, testEndpoint))
 	}

-	defer func() {
+	return client, testEndpoint, func() {
 		if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
 			defer serverProcMutex.Unlock()
 			if t.Failed() {
@@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
 				os.Stderr.Write(data)
 				slog.Warn("END OF SERVER")
 			}
-			err = os.Remove(lifecycle.ServerLogFile)
+			err := os.Remove(lifecycle.ServerLogFile)
 			if err != nil && !os.IsNotExist(err) {
 				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
 			}
 		}
-	}()
-
-	scheme, testEndpoint := GetTestEndpoint()
-
-	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
-		serverProcMutex.Lock()
-		fp, err := os.CreateTemp("", "ollama-server-*.log")
-		if err != nil {
-			t.Fatalf("failed to generate log file: %s", err)
-		}
-		lifecycle.ServerLogFile = fp.Name()
-		fp.Close()
-		assert.NoError(t, StartServer(ctx, testEndpoint))
 	}
+}

-	err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
-	if err != nil {
-		t.Fatalf("Error pulling model: %v", err)
-	}
+func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
+	DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
+}

-	// Make the request and get the response
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
-	if err != nil {
-		t.Fatalf("Error creating request: %v", err)
-	}
+func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
+	stallTimer := time.NewTimer(initialTimeout)
+	var buf bytes.Buffer
+	fn := func(response api.GenerateResponse) error {
+		// fmt.Print(".")
+		buf.Write([]byte(response.Response))
+		if !stallTimer.Reset(streamTimeout) {
+			return fmt.Errorf("stall was detected while streaming response, aborting")
+		}
+		return nil
+	}

-	// Set the content type for the request
-	req.Header.Set("Content-Type", "application/json")
+	stream := true
+	genReq.Stream = &stream
+	done := make(chan int)
+	var genErr error
+	go func() {
+		genErr = client.Generate(ctx, &genReq, fn)
+		done <- 0
+	}()

-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
-		t.Fatalf("Error making request: %v", err)
-	}
-	defer response.Body.Close()
-	body, err := io.ReadAll(response.Body)
-	assert.NoError(t, err)
-	assert.Equal(t, response.StatusCode, 200, string(body))
-
-	// Verify the response contains the expected data
-	// Verify the response is valid JSON
-	var payload api.GenerateResponse
-	err = json.Unmarshal(body, &payload)
-	if err != nil {
-		assert.NoError(t, err, body)
-	}
-
-	// Verify the response contains the expected data
-	atLeastOne := false
-	for _, resp := range anyResp {
-		if strings.Contains(strings.ToLower(payload.Response), resp) {
-			atLeastOne = true
-			break
+	select {
+	case <-stallTimer.C:
+		if buf.Len() == 0 {
+			t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
+		} else {
+			t.Errorf("generate stalled. Response so far:%s", buf.String())
+		}
+	case <-done:
+		require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)

+		response := buf.String()
+		atLeastOne := false
+		for _, resp := range anyResp {
+			if strings.Contains(strings.ToLower(response), resp) {
+				atLeastOne = true
+				break
+			}
 		}
+		require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
+		slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
+	case <-ctx.Done():
+		t.Error("outer test context done while waiting for generate")
 	}
-	assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
+}
+
+// Generate a set of requests
+// By default each request uses orca-mini as the model
+func GenerateRequests() ([]api.GenerateRequest, [][]string) {
+	return []api.GenerateRequest{
+			{
+				Model:  "orca-mini",
+				Prompt: "why is the ocean blue?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "why is the color of dirt brown?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of the us thanksgiving holiday?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of independence day?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the composition of air?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			},
+		},
+		[][]string{
+			[]string{"sunlight"},
+			[]string{"soil", "organic", "earth", "black", "tan"},
+			[]string{"england", "english", "massachusetts", "pilgrims"},
+			[]string{"fourth", "july", "declaration", "independence"},
+			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
+		}
 }
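
The rewritten helpers above replace fixed HTTP deadlines with a stall watchdog: the streaming callback resets a timer on every progress event, and a select fails only when no progress arrives for a whole window. Below is a self-contained sketch of that pattern, assuming nothing beyond the standard library; produce stands in for client.Pull or client.Generate, and it uses the conventional Stop-drain-Reset idiom rather than checking Reset's return value as the helpers above do.

package main

import (
	"errors"
	"fmt"
	"time"
)

// watchForStall runs produce in a goroutine and fails if no progress
// callback arrives within window.
func watchForStall(produce func(report func()) error, window time.Duration) error {
	stallTimer := time.NewTimer(window)
	done := make(chan error, 1) // buffered so the goroutine never leaks
	go func() {
		done <- produce(func() {
			// Each progress event pushes the deadline out again.
			if !stallTimer.Stop() {
				select {
				case <-stallTimer.C:
				default:
				}
			}
			stallTimer.Reset(window)
		})
	}()
	select {
	case <-stallTimer.C:
		return errors.New("stalled: no progress within window")
	case err := <-done:
		return err
	}
}

func main() {
	err := watchForStall(func(report func()) error {
		for i := 0; i < 3; i++ {
			time.Sleep(10 * time.Millisecond) // simulated download chunk
			report()
		}
		return nil
	}, 100*time.Millisecond)
	fmt.Println("result:", err)
}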
package llm

import (
	"fmt"
	"log/slog"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/gpu"
)

// This algorithm looks for a complete fit to determine if we need to unload other models
func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
	var estimatedVRAM uint64
	if opts.NumCtx > int(ggml.KV().ContextLength()) {
		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
		opts.NumCtx = int(ggml.KV().ContextLength())
	}

	if opts.NumCtx < 4 {
		opts.NumCtx = 4
	}

	// Split up the GPUs by type and try them
	for _, gpus := range allGpus.ByLibrary() {
		var layerCount int
		layerCount, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
		if opts.NumGPU < 0 {
			if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
				return true, estimatedVRAM
			}
		} else {
			if layerCount > 0 && layerCount >= opts.NumGPU {
				return true, estimatedVRAM
			}
		}
	}
	return false, estimatedVRAM
}

// Given a model and one or more GPU targets, predict how many layers and bytes we can load
// The GPUs provided must all be the same Library
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64) {
	if gpus[0].Library == "cpu" {
		return 0, 0
	}
	var memoryAvailable uint64
	for _, info := range gpus {
		memoryAvailable += info.FreeMemory
	}
	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))

	// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
	memoryMinimum := gpus[0].MinimumMemory

	for _, projector := range projectors {
		memoryMinimum += projectorMemoryRequirements(projector)

		// multimodal models require at least 2048 context
		opts.NumCtx = max(opts.NumCtx, 2048)
	}

	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()

	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
	if graphPartialOffload == 0 {
		graphPartialOffload = ggml.KV().GQA() * kv / 6
	}

	if graphFullOffload == 0 {
		graphFullOffload = graphPartialOffload
	}

	graphFullOffload *= uint64(len(gpus))
	graphPartialOffload *= uint64(len(gpus))

	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
	memoryRequiredTotal := memoryMinimum + graphFullOffload

	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
	memoryRequiredPartial := memoryMinimum + graphPartialOffload

	if memoryRequiredPartial > memoryAvailable {
		slog.Debug("insufficient VRAM to load any model layers")
		return 0, 0
	}

	var layerCount int
	layers := ggml.Tensors().Layers()
	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()

		// KV is proportional to the number of layers
		memoryLayer += kv / ggml.KV().BlockCount()

		memoryRequiredTotal += memoryLayer
		if memoryAvailable > memoryRequiredPartial+memoryLayer {
			memoryRequiredPartial += memoryLayer
			layerCount++
		}
	}

	var memoryLayerOutput uint64
	for k, v := range layers {
		if !strings.HasPrefix(k, "blk.") {
			memoryLayerOutput += v.size()
		}
	}
	memoryRequiredTotal += memoryLayerOutput

	if memoryAvailable > memoryRequiredTotal {
		layerCount = int(ggml.KV().BlockCount()) + 1
		memoryRequiredPartial = memoryRequiredTotal
	}

	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv

	slog.Info(
		"offload to gpu",
		slog.Group(
			"layers",
			// actual number of layers offloaded
			"real", opts.NumGPU,
			// estimated number of layers that can be offloaded
			"estimate", layerCount,
		),
		slog.Group(
			"memory",
			// memory available for offloading
			"available", format.HumanBytes2(memoryAvailable),
			slog.Group(
				"required",
				// memory required for full offloading
				"full", format.HumanBytes2(memoryRequiredTotal),
				// memory required to offload layers.estimate layers
				"partial", format.HumanBytes2(memoryRequiredPartial),
				// memory of KV cache
				"kv", format.HumanBytes2(kv),
			),
			slog.Group(
				"weights",
				// memory of the weights
				"total", format.HumanBytes2(memoryWeights),
				// memory of repeating layers
				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
				// memory of non-repeating layers
				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
			),
			slog.Group(
				"graph",
				// memory of graph when fully offloaded
				"full", format.HumanBytes2(graphFullOffload),
				// memory of graph when not fully offloaded
				"partial", format.HumanBytes2(graphPartialOffload),
			),
		),
	)
	return layerCount, uint64(memoryRequiredPartial)
}
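
The fp16 KV-cache line above is easy to sanity-check by hand. Here is a standalone calculation with Llama-2-7B-style hyperparameters, which are assumed here for illustration; in the real code they come from the GGUF metadata via ggml.KV().

package main

import "fmt"

func main() {
	// kv = 2 (k and v) * 2 (bytes per float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
	var (
		nCtx    uint64 = 2048 // requested context length
		nLayer  uint64 = 32   // block count
		nEmbd   uint64 = 4096 // embedding length
		nHead   uint64 = 32   // attention heads
		nHeadKV uint64 = 32   // KV heads (equals nHead when GQA is off)
	)
	kv := 2 * 2 * nCtx * nLayer * nEmbd / nHead * nHeadKV
	fmt.Printf("kv cache: %d bytes (%.2f GiB)\n", kv, float64(kv)/(1<<30))
	// Prints 1073741824 bytes (1.00 GiB); with GQA (say nHeadKV = 8) the
	// same formula shrinks the cache 4x.
}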
@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+	"runtime"
 	"strings"

 	"golang.org/x/exp/slices"
@@ -138,6 +139,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
 	return servers
 }

+// Return the optimal server for this CPU architecture
+func serverForCpu() string {
+	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+		return "metal"
+	}
+	variant := gpu.GetCPUVariant()
+	availableServers := availableServers()
+	if variant != "" {
+		for cmp := range availableServers {
+			if cmp == "cpu_"+variant {
+				return cmp
+			}
+		}
+	}
+	return "cpu"
+}
+
 // extract extracts the embedded files to the target directory
 func extractFiles(targetDir string, glob string) error {
 	files, err := fs.Glob(libEmbed, glob)
...
@@ -21,21 +21,43 @@ import (
 	"strings"
 	"time"

+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 )

-// LlamaServer is an instance of the llama.cpp server
-type LlamaServer struct {
+type LlamaServer interface {
+	Ping(ctx context.Context) error
+	WaitUntilRunning(ctx context.Context) error
+	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
+	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Tokenize(ctx context.Context, content string) ([]int, error)
+	Detokenize(ctx context.Context, tokens []int) (string, error)
+	Close() error
+	EstimatedVRAM() uint64
+}
+
+// llmServer is an instance of the llama.cpp server
+type llmServer struct {
 	port    int
 	cmd     *exec.Cmd
 	done    chan error // Channel to signal when the process exits
 	status  *StatusWriter
 	options api.Options
+
+	// TODO - this should be broken down by GPU
+	estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
+
+	sem *semaphore.Weighted
 }

-func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
+func LoadModel(model string) (*GGML, error) {
+	if _, err := os.Stat(model); err != nil {
+		return nil, err
+	}
+
 	f, err := os.Open(model)
 	if err != nil {
 		return nil, err
@@ -43,10 +65,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
 	defer f.Close()

 	ggml, _, err := DecodeGGML(f)
-	if err != nil {
-		return nil, err
-	}
+	return ggml, err
+}

+// NewLlamaServer will run a server for the given GPUs
+// The gpu list must be a single family.
+func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
+	var err error
 	if opts.NumCtx > int(ggml.KV().ContextLength()) {
 		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
 		opts.NumCtx = int(ggml.KV().ContextLength())
@@ -56,130 +81,50 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
 		opts.NumCtx = 4
 	}

-	memoryAvailable, _ := gpu.CheckVRAM()
-	info := gpu.GetGPUInfo()
+	cpuRunner := ""
+	var estimatedVRAM uint64
+	var systemMemory uint64
+	if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {

-	memoryMinimum := info.MinimumMemory
-	for _, projector := range projectors {
-		memoryMinimum += projectorMemoryRequirements(projector)
-
-		// multimodal models require at least 2048 context
-		opts.NumCtx = max(opts.NumCtx, 2048)
-	}
+		// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner

-	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
-	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
-
-	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
-	if graphPartialOffload == 0 {
-		graphPartialOffload = ggml.KV().GQA() * kv / 6
-	}
-
-	if graphFullOffload == 0 {
-		graphFullOffload = graphPartialOffload
-	}
-
-	graphFullOffload *= uint64(info.DeviceCount)
-	graphPartialOffload *= uint64(info.DeviceCount)
-
-	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
-	memoryRequiredTotal := memoryMinimum + graphFullOffload
-
-	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
-	memoryRequiredPartial := memoryMinimum + graphPartialOffload
-
-	if info.Library != "metal" {
-		if memoryRequiredPartial > memoryAvailable {
-			info.Library = "cpu"
-		}
-	}
-
-	var layerCount int
-	layers := ggml.Tensors().Layers()
-	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
-		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
-
-		// KV is proportional to the number of layers
-		memoryLayer += kv / ggml.KV().BlockCount()
-
-		memoryRequiredTotal += memoryLayer
-		if memoryAvailable > memoryRequiredPartial+memoryLayer {
-			memoryRequiredPartial += memoryLayer
-			layerCount++
+		cpuRunner = serverForCpu()
+	} else {
+		if gpus[0].Library == "metal" {
+			memInfo, err := gpu.GetCPUMem()
+			if err != nil {
+				slog.Error("failed to lookup system memory", "error", err)
+			} else {
+				systemMemory = memInfo.TotalMemory
+				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
+			}
 		}
-	}
+		var layers int
+		layers, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)

-	var memoryLayerOutput uint64
-	for k, v := range layers {
-		if !strings.HasPrefix(k, "blk.") {
-			memoryLayerOutput += v.size()
+		if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
+			// disable partial offloading when model is greater than total system memory as this
+			// can lead to locking up the system
+			opts.NumGPU = 0
+		} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
+			opts.NumGPU = layers
 		}
 	}
-	memoryRequiredTotal += memoryLayerOutput

-	if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
-		// disable partial offloading when model is greater than total system memory
-		opts.NumGPU = 0
-	} else if memoryAvailable > memoryRequiredTotal {
-		layerCount = int(ggml.KV().BlockCount()) + 1
-		memoryRequiredPartial = memoryRequiredTotal
-	}
-
-	if opts.NumGPU < 0 {
-		opts.NumGPU = layerCount
-	}
-
-	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
-
-	slog.Info(
-		"offload to gpu",
-		slog.Group(
-			"layers",
-			// actual number of layers offloaded
-			"real", opts.NumGPU,
-			// estimated number of layers that can be offloaded
-			"estimate", layerCount,
-		),
-		slog.Group(
-			"memory",
-			// memory available for offloading
-			"available", format.HumanBytes2(memoryAvailable),
-			slog.Group(
-				"required",
-				// memory required for full offloading
-				"full", format.HumanBytes2(memoryRequiredTotal),
-				// memory required to offload layers.estimate layers
-				"partial", format.HumanBytes2(memoryRequiredPartial),
-				// memory of KV cache
-				"kv", format.HumanBytes2(kv),
-			),
-			slog.Group(
-				"weights",
-				// memory of the weights
-				"total", format.HumanBytes2(memoryWeights),
-				// memory of repeating layers
-				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
-				// memory of non-repeating layers
-				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
-			),
-			slog.Group(
-				"graph",
-				// memory of graph when fully offloaded
-				"full", format.HumanBytes2(graphFullOffload),
-				// memory of graph when not fully offloaded
-				"partial", format.HumanBytes2(graphPartialOffload),
-			),
-		),
-	)
+	// Loop through potential servers
+	finalErr := fmt.Errorf("no suitable llama servers found")

 	if len(adapters) > 1 {
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}

 	availableServers := availableServers()
-	servers := serversForGpu(info)
+	var servers []string
+	if cpuRunner != "" {
+		servers = []string{cpuRunner}
+	} else {
+		servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
+	}

 	demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
 	if demandLib != "" {
 		serverPath := availableServers[demandLib]
@@ -192,7 +137,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
 	}

 	if len(servers) == 0 {
-		return nil, fmt.Errorf("no servers found for %v", info)
+		return nil, fmt.Errorf("no servers found for %v", gpus)
 	}

 	params := []string{
@@ -249,8 +194,18 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
 		params = append(params, "--numa")
 	}

-	// Loop through potential servers
-	var finalErr error
+	// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
+	numParallel := 1
+	if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
+		numParallel, err = strconv.Atoi(onp)
+		if err != nil || numParallel <= 0 {
+			err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err)
+			slog.Error("misconfiguration", "error", err)
+			return nil, err
+		}
+	}
+	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))

 	for i := 0; i < len(servers); i++ {
 		dir := availableServers[servers[i]]
@@ -275,30 +230,49 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
 		}

 		// append the server directory to LD_LIBRARY_PATH/PATH
 		libraryPaths := []string{dir}

 		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
 			// Append our runner directory to the path
 			// This will favor system libraries over our bundled library dependencies
 			libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
 		}

+		// Note: we always put the dependency path first
+		// since this was the exact version we verified for AMD GPUs
+		// and we favor what the user had in their path
+		if gpus[0].DependencyPath != "" {
+			// TODO refine for multi-gpu support
+			libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
+		}
+
 		server := filepath.Join(dir, "ollama_llama_server")
 		if runtime.GOOS == "windows" {
 			server = server + ".exe"
 		}

-		s := &LlamaServer{
+		s := &llmServer{
 			port:          port,
 			cmd:           exec.Command(server, finalParams...),
 			status:        NewStatusWriter(os.Stderr),
 			options:       opts,
+			estimatedVRAM: estimatedVRAM,
+			sem:           semaphore.NewWeighted(int64(numParallel)),
 		}
+
 		libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
-		slog.Debug(libEnv)
 		s.cmd.Env = append(os.Environ(), libEnv)
 		s.cmd.Stdout = os.Stdout
 		s.cmd.Stderr = s.status

+		// TODO - multiple GPU selection logic...
+		key, val := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
+		if key != "" {
+			s.cmd.Env = append(s.cmd.Env, key+"="+val)
+		}
+
 		slog.Info("starting llama server", "cmd", s.cmd.String())
+		// Log at debug as the environment is inherited and might contain sensitive information
+		slog.Debug("subprocess", "environment", s.cmd.Env)

 		if err = s.cmd.Start(); err != nil {
 			msg := ""
@@ -316,6 +290,13 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
 			_ = s.cmd.Wait()
 		}()

+		// TODO - make sure this is all wired up correctly
+		// if err = s.WaitUntilRunning(); err != nil {
+		// 	slog.Error("error starting llama server", "server", servers[i], "error", err)
+		// 	s.Close()
+		// 	finalErr = err
+		// 	continue
+		// }
 		return s, nil
 	}
@@ -353,6 +334,21 @@ const ( // iota is reset to 0
 	ServerStatusError
 )

+func (s ServerStatus) ToString() string {
+	switch s {
+	case ServerStatusReady:
+		return "llm server ready"
+	case ServerStatusNoSlotsAvaialble:
+		return "llm busy - no slots available"
+	case ServerStatusLoadingModel:
+		return "llm server loading model"
+	case ServerStatusNotResponding:
+		return "llm server not responding"
+	default:
+		return "llm server error"
+	}
+}
+
 type ServerStatusResp struct {
 	Status    string `json:"status"`
 	SlotsIdle int    `json:"slots_idle"`
@@ -360,7 +356,7 @@ type ServerStatusResp struct {
 	Error     string `json:"error"`
 }

-func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
+func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	// Fail fast if its exited
 	if s.cmd.ProcessState != nil {
 		msg := ""
@@ -407,7 +403,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	}
 }

-func (s *LlamaServer) Ping(ctx context.Context) error {
+func (s *llmServer) Ping(ctx context.Context) error {
 	_, err := s.getServerStatus(ctx)
 	if err != nil {
 		slog.Debug("server unhealthy", "error", err)
@@ -416,7 +412,7 @@ func (s *llmServer) Ping(ctx context.Context) error {
 	return nil
 }

-func (s *LlamaServer) WaitUntilRunning() error {
+func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
 	// TODO we need to wire up a better way to detect hangs during model load and startup of the server
 	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
@@ -427,6 +423,9 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	var lastStatus ServerStatus = -1

 	for {
 		select {
+		case <-ctx.Done():
+			slog.Info("context expired before server started")
+			return fmt.Errorf("timed out waiting for llama runner to start")
 		case err := <-s.done:
 			msg := ""
 			if s.status != nil && s.status.LastErrMsg != "" {
@@ -450,9 +449,9 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 				return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 			}

-			ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+			c, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
 			defer cancel()
-			status, err := s.getServerStatus(ctx)
+			status, err := s.getServerStatus(c)
 			if err != nil && lastStatus != status {
 				slog.Debug("server not yet available", "error", err)
 				lastStatus = status
@@ -538,7 +537,12 @@ type CompletionResponse struct {
 	EvalDuration    time.Duration
 }

-func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return err
+	}
+	defer s.sem.Release(1)
 	request := map[string]any{
 		"prompt": req.Prompt,
 		"stream": true,
@@ -569,7 +573,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %d", status)
+		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

 	if req.Format == "json" {
@@ -716,13 +720,18 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }

-func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return nil, err
+	}
+	defer s.sem.Release(1)
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
@@ -768,13 +777,13 @@ type TokenizeResponse struct {
 	Tokens []int `json:"tokens"`
 }

-func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
+func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return nil, err
-	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

 	data, err := json.Marshal(TokenizeRequest{Content: content})
@@ -820,13 +829,13 @@ type DetokenizeResponse struct {
 	Content string `json:"content"`
 }

-func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
+func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return "", err
-	} else if status != ServerStatusReady {
-		return "", fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

 	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@@ -864,7 +873,7 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
 	return decoded.Content, nil
 }

-func (s *LlamaServer) Close() error {
+func (s *llmServer) Close() error {
 	if s.cmd != nil {
 		slog.Debug("stopping llama server")
 		return s.cmd.Process.Kill()
@@ -873,6 +882,10 @@ func (s *llmServer) Close() error {
 	return nil
 }

+func (s *llmServer) EstimatedVRAM() uint64 {
+	return s.estimatedVRAM
+}
+
 func parseDurationMs(ms float64) time.Duration {
 	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
 	if err != nil {
...
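
The sem field added to llmServer above is what enforces OLLAMA_NUM_PARALLEL: Completion and Embedding each hold one token for the duration of a request, so excess callers block in Acquire until a slot frees or their context is canceled. Below is a runnable sketch of that gating; the slot count and the sleep are placeholder stand-ins for real requests.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"
)

func main() {
	const numParallel = 2
	sem := semaphore.NewWeighted(numParallel)

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()
			// Block until a slot is free, or give up when the context expires.
			if err := sem.Acquire(ctx, 1); err != nil {
				fmt.Println("request", id, "gave up waiting:", err)
				return
			}
			defer sem.Release(1)
			fmt.Println("request", id, "holding a slot")
			time.Sleep(50 * time.Millisecond) // stand-in for a completion call
		}(i)
	}
	wg.Wait()
}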
@@ -15,11 +15,8 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
-	"reflect"
-	"runtime"
 	"strconv"
 	"strings"
-	"sync"
 	"syscall"
 	"time"
@@ -38,7 +35,8 @@ import (
 var mode string = gin.DebugMode

 type Server struct {
-	addr net.Addr
+	addr  net.Addr
+	sched *Scheduler
 }

 func init() {
@@ -53,88 +51,8 @@ func init() {
 	gin.SetMode(mode)
 }

-var loaded struct {
-	mu sync.Mutex
-
-	llama *llm.LlamaServer
-
-	expireTimer *time.Timer
-
-	model      string
-	adapters   []string
-	projectors []string
-	*api.Options
-}
-
 var defaultSessionDuration = 5 * time.Minute

-func unload() {
-	if loaded.llama != nil {
-		loaded.llama.Close()
-	}
-
-	loaded.llama = nil
-	loaded.model = ""
-	loaded.adapters = nil
-	loaded.projectors = nil
-	loaded.Options = nil
-}
-
-// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
-	ctx, cancel := context.WithTimeout(c, 10*time.Second)
-	defer cancel()
-
-	needLoad := loaded.llama == nil || // is there a model loaded?
-		loaded.model != model.ModelPath || // has the base model changed?
-		!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
-		!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
-		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
-		loaded.llama.Ping(ctx) != nil
-
-	if needLoad {
-		if loaded.llama != nil {
-			slog.Info("changing loaded model")
-			unload()
-		}
-
-		llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
-		if err != nil {
-			// some older models are not compatible with newer versions of llama.cpp
-			// show a generalized compatibility error until there is a better way to
-			// check for model compatibility
-			if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
-				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
-			}
-
-			return err
-		}
-
-		loaded.model = model.ModelPath
-		loaded.adapters = model.AdapterPaths
-		loaded.projectors = model.ProjectorPaths
-		loaded.llama = llama
-		loaded.Options = &opts
-
-		if err = llama.WaitUntilRunning(); err != nil {
-			slog.Error("error loading llama server", "error", err)
-			unload()
-			return err
-		}
-	}
-
-	if loaded.expireTimer == nil {
-		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
-			loaded.mu.Lock()
-			defer loaded.mu.Unlock()
-			unload()
-		})
-	}
-
-	loaded.expireTimer.Reset(sessionDuration)
-	return nil
-}
-
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
@@ -154,9 +72,7 @@ func isSupportedImageType(image []byte) bool {
 	return slices.Contains(allowedTypes, contentType)
 }

-func GenerateHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
-
+func (s *Server) GenerateHandler(c *gin.Context) {
 	checkpointStart := time.Now()

 	var req api.GenerateRequest
@@ -224,7 +140,11 @@ func GenerateHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 	}

-	if err := load(c, model, opts, sessionDuration); err != nil {
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -275,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
 		sb.Reset()
 		if req.Context != nil {
-			prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
+			prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
@@ -297,9 +217,6 @@ func GenerateHandler(c *gin.Context) {
 		defer close(ch)

 		fn := func(r llm.CompletionResponse) {
-			// Update model expiration
-			loaded.expireTimer.Reset(sessionDuration)
-
 			// Build up the full response
 			if _, err := generated.WriteString(r.Content); err != nil {
 				ch <- gin.H{"error": err.Error()}
@@ -331,7 +248,7 @@ func GenerateHandler(c *gin.Context) {
 			}

 			// TODO (jmorganca): encode() should not strip special tokens
-			tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
+			tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
 			if err != nil {
 				ch <- gin.H{"error": err.Error()}
 				return
@@ -359,7 +276,7 @@ func GenerateHandler(c *gin.Context) {
 			Images:  images,
 			Options: opts,
 		}
-		if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
+		if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -421,10 +338,7 @@ func getDefaultSessionDuration() time.Duration {
 	return defaultSessionDuration
 }

-func EmbeddingsHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
-
+func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -469,7 +383,11 @@ func EmbeddingsHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 	}

-	if err := load(c, model, opts, sessionDuration); err != nil {
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -480,7 +398,7 @@ func EmbeddingsHandler(c *gin.Context) {
 		return
 	}

-	embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -493,7 +411,7 @@ func EmbeddingsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, resp)
 }

-func PullModelHandler(c *gin.Context) {
+func (s *Server) PullModelHandler(c *gin.Context) {
 	var req api.PullRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -542,7 +460,7 @@ func PullModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }

-func PushModelHandler(c *gin.Context) {
+func (s *Server) PushModelHandler(c *gin.Context) {
 	var req api.PushRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -591,7 +509,7 @@ func PushModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }

-func CreateModelHandler(c *gin.Context) {
+func (s *Server) CreateModelHandler(c *gin.Context) {
 	var req api.CreateRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -664,7 +582,7 @@ func CreateModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }

-func DeleteModelHandler(c *gin.Context) {
+func (s *Server) DeleteModelHandler(c *gin.Context) {
 	var req api.DeleteRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -709,7 +627,7 @@ func DeleteModelHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, nil)
 }

-func ShowModelHandler(c *gin.Context) {
+func (s *Server) ShowModelHandler(c *gin.Context) {
 	var req api.ShowRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -809,7 +727,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	return resp, nil
 }

-func ListModelsHandler(c *gin.Context) {
+func (s *Server) ListModelsHandler(c *gin.Context) {
 	models := make([]api.ModelResponse, 0)
 	manifestsPath, err := GetManifestPath()
 	if err != nil {
@@ -869,7 +787,7 @@ func ListModelsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }

-func CopyModelHandler(c *gin.Context) {
+func (s *Server) CopyModelHandler(c *gin.Context) {
 	var req api.CopyRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -901,7 +819,7 @@ func CopyModelHandler(c *gin.Context) {
 	}
 }

-func HeadBlobHandler(c *gin.Context) {
+func (s *Server) HeadBlobHandler(c *gin.Context) {
 	path, err := GetBlobsPath(c.Param("digest"))
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -916,7 +834,7 @@ func HeadBlobHandler(c *gin.Context) {
 	c.Status(http.StatusOK)
 }

-func CreateBlobHandler(c *gin.Context) {
+func (s *Server) CreateBlobHandler(c *gin.Context) {
 	path, err := GetBlobsPath(c.Param("digest"))
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -1063,27 +981,27 @@ func (s *Server) GenerateRoutes() http.Handler {
 		allowedHostsMiddleware(s.addr),
 	)

-	r.POST("/api/pull", PullModelHandler)
-	r.POST("/api/generate", GenerateHandler)
-	r.POST("/api/chat", ChatHandler)
-	r.POST("/api/embeddings", EmbeddingsHandler)
-	r.POST("/api/create", CreateModelHandler)
-	r.POST("/api/push", PushModelHandler)
-	r.POST("/api/copy", CopyModelHandler)
-	r.DELETE("/api/delete", DeleteModelHandler)
-	r.POST("/api/show", ShowModelHandler)
-	r.POST("/api/blobs/:digest", CreateBlobHandler)
-	r.HEAD("/api/blobs/:digest", HeadBlobHandler)
+	r.POST("/api/pull", s.PullModelHandler)
+	r.POST("/api/generate", s.GenerateHandler)
+	r.POST("/api/chat", s.ChatHandler)
+	r.POST("/api/embeddings", s.EmbeddingsHandler)
+	r.POST("/api/create", s.CreateModelHandler)
+	r.POST("/api/push", s.PushModelHandler)
+	r.POST("/api/copy", s.CopyModelHandler)
+	r.DELETE("/api/delete", s.DeleteModelHandler)
+	r.POST("/api/show", s.ShowModelHandler)
+	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
+	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)

 	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
+	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)

 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
 			c.String(http.StatusOK, "Ollama is running")
 		})

-		r.Handle(method, "/api/tags", ListModelsHandler)
+		r.Handle(method, "/api/tags", s.ListModelsHandler)
 		r.Handle(method, "/api/version", func(c *gin.Context) {
 			c.JSON(http.StatusOK, gin.H{"version": version.Version})
 		})
@@ -1137,7 +1055,9 @@ func Serve(ln net.Listener) error {
 		}
 	}

-	s := &Server{addr: ln.Addr()}
+	ctx, done := context.WithCancel(context.Background())
+	sched := InitScheduler(ctx)
+	s := &Server{addr: ln.Addr(), sched: sched}
 	r := s.GenerateRoutes()

 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@@ -1150,7 +1070,8 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		unload()
+		done()
+		sched.unloadAllRunners()
 		gpu.Cleanup()
 		os.Exit(0)
 	}()
@@ -1158,12 +1079,12 @@ func Serve(ln net.Listener) error {
 	if err := llm.Init(); err != nil {
 		return fmt.Errorf("unable to initialize llm library %w", err)
 	}
-	if runtime.GOOS == "linux" { // TODO - windows too
-		// check compatibility to log warnings
-		if _, err := gpu.CheckVRAM(); err != nil {
-			slog.Info(err.Error())
-		}
-	}
+
+	s.sched.Run(ctx)
+
+	// At startup we retrieve GPU information so we can get log messages before loading a model
+	// This will log warnings to the log in case we have problems with detected GPUs
+	_ = gpu.GetGPUInfo()

 	return srvr.Serve(ln)
 }
@@ -1219,9 +1140,9 @@ func streamResponse(c *gin.Context, ch chan any) {
 }

 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
-		return loaded.llama.Tokenize(ctx, s)
+		return runner.llama.Tokenize(ctx, s)
 	}

 	prompt, err := ChatPrompt(template, messages, numCtx, encode)
@@ -1232,10 +1153,7 @@ func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
 	return prompt, nil
 }

-func ChatHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
-
+func (s *Server) ChatHandler(c *gin.Context) {
 	checkpointStart := time.Now()

 	var req api.ChatRequest
@@ -1292,7 +1210,11 @@ func ChatHandler(c *gin.Context) {
 		sessionDuration = req.KeepAlive.Duration
 	}

-	if err := load(c, model, opts, sessionDuration); err != nil {
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -1309,7 +1231,7 @@ func ChatHandler(c *gin.Context) {
 		}, req.Messages...)
 	}

-	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+	prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
@@ -1352,8 +1274,6 @@ func ChatHandler(c *gin.Context) {
 		defer close(ch)

 		fn := func(r llm.CompletionResponse) {
-			// Update model expiration
-			loaded.expireTimer.Reset(sessionDuration)
-
 			resp := api.ChatResponse{
 				Model: req.Model,
@@ -1376,7 +1296,7 @@ func ChatHandler(c *gin.Context) {
 			ch <- resp
 		}

-		if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt: prompt,
Format: req.Format, Format: req.Format,
Images: images, Images: images,
......
package server
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"golang.org/x/exp/slices"
)
type LlmRequest struct {
ctx context.Context //nolint:containedctx
model *Model
ggml *llm.GGML // TODO - how large is this, and do we need to free it after we've finished loading?
opts api.Options
sessionDuration time.Duration
successCh chan *runnerRef
errCh chan error
}
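// Scheduler serializes load/unload decisions for runners. Requests enter via
// pendingReqCh; completed requests arrive on finishedReqCh; idle or evicted
// runners are pushed onto expiredCh; and unloadedCh signals that a runner's
// memory has been reclaimed so a waiting request can retry placement.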
type Scheduler struct {
pendingReqCh chan *LlmRequest
finishedReqCh chan *LlmRequest
expiredCh chan *runnerRef
unloadedCh chan interface{}
loaded map[string]*runnerRef
loadedMu sync.Mutex
loadFn func(req *LlmRequest, gpus gpu.GpuInfoList)
newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
getGpuFn func() gpu.GpuInfoList
}
// TODO set this to zero after a release or two, to enable multiple models by default
var loadedMax = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
var maxQueuedRequests = 10 // TODO configurable
func InitScheduler(ctx context.Context) *Scheduler {
maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
if maxRunners != "" {
m, err := strconv.Atoi(maxRunners)
if err != nil {
slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
} else {
loadedMax = m
}
}
sched := &Scheduler{
pendingReqCh: make(chan *LlmRequest, maxQueuedRequests),
finishedReqCh: make(chan *LlmRequest, maxQueuedRequests),
expiredCh: make(chan *runnerRef, maxQueuedRequests),
unloadedCh: make(chan interface{}, maxQueuedRequests),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: gpu.GetGPUInfo,
}
sched.loadFn = sched.load
return sched
}
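// For example (hypothetical value), allowing up to two resident models before
// the scheduler evicts one; the variable must be set before InitScheduler runs:
//
//	os.Setenv("OLLAMA_MAX_LOADED_MODELS", "2")
//	sched := InitScheduler(ctx)
//	sched.Run(ctx)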
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
ggml, err := llm.LoadModel(model.ModelPath)
req := &LlmRequest{
ctx: c,
model: model,
ggml: ggml,
opts: opts,
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef),
errCh: make(chan error, 1),
}
if err != nil {
req.errCh <- err
return req.successCh, req.errCh
}
select {
case s.pendingReqCh <- req:
default:
req.errCh <- fmt.Errorf("server busy, please try again: maximum pending requests exceeded")
}
return req.successCh, req.errCh
}
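// Illustrative sketch of the intended call pattern (mirroring ChatHandler);
// model, opts, and the keep-alive duration are placeholders:
//
//	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, 5*time.Minute)
//	select {
//	case runner := <-rCh:
//		// use runner.llama; the ref count drops when the request context is done
//	case err := <-eCh:
//		// the queue was full or the model file could not be loaded
//	}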
// Returns immediately; spawns goroutines for the scheduler, which shut down when ctx is done
func (s *Scheduler) Run(ctx context.Context) {
slog.Debug("starting llm scheduler")
go func() {
s.processPending(ctx)
}()
go func() {
s.processCompleted(ctx)
}()
}
func (s *Scheduler) processPending(ctx context.Context) {
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case pending := <-s.pendingReqCh:
// Block other requests until we get this pending request running
for {
var runnerToExpire *runnerRef
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
loadedCount := len(s.loaded)
s.loadedMu.Unlock()
if runner != nil {
if runner.needsReload(ctx, pending) {
runnerToExpire = runner
} else {
// Runner is usable, return it
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
} else if loadedCount == 0 {
slog.Debug("loading first model", "model", pending.model.ModelPath)
gpus := s.getGpuFn()
g := pickBestFitGPUs(pending, gpus)
if g != nil {
gpus = g
}
s.loadFn(pending, gpus)
break
} else if loadedMax > 0 && loadedCount >= loadedMax {
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
runnerToExpire = s.findRunnerToUnload(pending)
} else {
// At least one other model is loaded, so we have to see if the new one fits
// Get a refreshed GPU list
gpus := s.getGpuFn()
// Update free memory from currently loaded models
s.updateFreeSpace(gpus)
gpus = pickBestFitGPUs(pending, gpus)
if gpus != nil {
slog.Debug("new model fits with existing models, loading")
s.loadFn(pending, gpus)
break
}
runnerToExpire = s.findRunnerToUnload(pending)
}
if runnerToExpire == nil {
// Shouldn't happen
slog.Error("runner to expire was nil!")
continue
}
// Trigger an expiration to unload once it's done
runnerToExpire.refMu.Lock()
slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
if runnerToExpire.expireTimer != nil {
runnerToExpire.expireTimer.Stop()
runnerToExpire.expireTimer = nil
}
runnerToExpire.sessionDuration = 0
if runnerToExpire.refCount <= 0 {
s.expiredCh <- runnerToExpire
}
runnerToExpire.refMu.Unlock()
// Wait for the unload to happen
// Note: at this point we're queueing up all incoming requests, even if they were for
// a different model that's loaded and not scheduled to be removed.
slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case <-s.unloadedCh:
slog.Debug("unload completed", "model", runnerToExpire.model)
continue
}
}
case <-s.unloadedCh:
// An unload request when there are no pending requests can be ignored
slog.Debug("ignoring unload event with no pending requests")
}
}
}
func (s *Scheduler) processCompleted(ctx context.Context) {
// Process completed requests, expired timers, and unloading models
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath)
continue
}
runner.refMu.Lock()
runner.refCount--
if runner.refCount <= 0 {
if runner.sessionDuration <= 0 {
slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
s.expiredCh <- runner
} else if runner.expireTimer == nil {
slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
slog.Debug("timer expired, expiring to unload", "model", runner.model)
runner.refMu.Lock()
defer runner.refMu.Unlock()
if runner.expireTimer != nil {
runner.expireTimer.Stop()
}
s.expiredCh <- runner
})
} else {
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer.Reset(runner.sessionDuration)
}
}
slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
runner.refMu.Unlock()
case runner := <-s.expiredCh:
slog.Debug("runner expired event received", "model", runner.model)
runner.refMu.Lock()
if runner.refCount > 0 {
// Shouldn't happen, but safeguard to ensure no leaked runners
slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes
// So queue up another expired event
time.Sleep(10 * time.Millisecond)
s.expiredCh <- runner
}(runner)
runner.refMu.Unlock()
continue
}
slog.Debug("got lock to unload", "model", runner.model)
runner.unload()
s.loadedMu.Lock()
delete(s.loaded, runner.model)
s.loadedMu.Unlock()
slog.Debug("runner released", "model", runner.model)
runner.refMu.Unlock()
slog.Debug("sending an unloaded event", "model", runner.model)
s.unloadedCh <- struct{}{}
}
}
}
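// Taken together, a runner's lifecycle is: load() starts it with refCount 1;
// useLoadedRunner and finishedReqCh adjust the count per request; once idle,
// sessionDuration (or an immediate expiration) pushes it onto expiredCh, where
// unload() frees it and unloadedCh wakes any request waiting for room.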
// useLoadedRunner completes the pending request and sends the runner back to the requester,
// wires up a finished event for when the request context is done,
// and updates the session duration (the expiration timer is reset once the request finishes)
func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
runner.refMu.Lock()
defer runner.refMu.Unlock()
runner.refCount++
runner.sessionDuration = pending.sessionDuration
pending.successCh <- runner
go func() {
<-pending.ctx.Done()
slog.Debug("context for request finished")
finished <- pending
}()
}
func (s *Scheduler) load(req *LlmRequest, gpus gpu.GpuInfoList) {
llama, err := s.newServerFn(gpus, req.model.ModelPath, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err
return
}
runner := &runnerRef{}
runner.model = req.model.ModelPath
runner.adapters = req.model.AdapterPaths
runner.projectors = req.model.ProjectorPaths
runner.llama = llama
runner.Options = &req.opts
runner.sessionDuration = req.sessionDuration
runner.gpus = gpus
runner.estimatedVRAM = llama.EstimatedVRAM()
runner.loading = true
runner.refCount = 1
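// Hold refMu until the goroutine below finishes waiting on the server, so
// needsReload and unload cannot observe a half-started runner.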
runner.refMu.Lock()
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
go func() {
defer runner.refMu.Unlock()
if err = llama.WaitUntilRunning(req.ctx); err != nil {
slog.Error("error loading llama server", "error", err)
runner.refCount--
req.errCh <- err
slog.Debug("triggering expiration for failed load", "model", runner.model)
s.expiredCh <- runner
return
}
slog.Debug("finished setting up runner", "model", req.model.ModelPath)
runner.loading = false
go func() {
<-req.ctx.Done()
slog.Debug("context for request finished")
s.finishedReqCh <- req
}()
req.successCh <- runner
}()
}
func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
type predKey struct {
Library string
ID string
}
predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
for _, r := range s.loaded {
r.refMu.Lock()
gpuIDs := make([]string, 0, len(r.gpus))
if r.llama != nil {
// TODO this should be broken down by GPU instead of assuming uniform spread
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
for _, gpu := range r.gpus {
gpuIDs = append(gpuIDs, gpu.ID)
}
for _, gpu := range allGpus {
if slices.Contains(gpuIDs, gpu.ID) {
predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
}
}
} else {
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
}
r.refMu.Unlock()
}
s.loadedMu.Unlock()
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
for i := range allGpus {
if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
if p > allGpus[i].TotalMemory {
// Shouldn't happen
slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
allGpus[i].FreeMemory = 0
} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
// after we start our first runner, then we'll never account for that, so picking the smallest free value seems prudent.
allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
}
slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
}
}
}
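// Worked example (mirrors TestUpdateFreeSpace below): two runners estimated at
// 100 and 200 bytes, each spread across two GPUs, charge 50+100 = 150 bytes to
// each GPU, so a GPU reporting 900 free of 1000 total is clamped to 1000-150 = 850.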
type runnerRef struct {
refMu sync.Mutex
// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
refCount uint // prevent unloading if > 0
// unloading bool // set to true when we are trying to unload the runner
llama llm.LlamaServer
loading bool // True only during initial load, then false forever
gpus gpu.GpuInfoList // Recorded at time of provisioning
estimatedVRAM uint64
sessionDuration time.Duration
expireTimer *time.Timer
model string
adapters []string
projectors []string
*api.Options
}
// The refMu must already be held when calling unload
func (runner *runnerRef) unload() {
if runner.llama != nil {
runner.llama.Close()
}
runner.llama = nil
runner.adapters = nil
runner.projectors = nil
runner.Options = nil
runner.gpus = nil
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
runner.refMu.Lock()
defer runner.refMu.Unlock()
// Ignore the NumGPU settings for comparison
optsExisting := runner.Options.Runner
optsExisting.NumGPU = -1
optsNew := req.opts.Runner
optsNew.NumGPU = -1
timeout := 10 * time.Second
if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
}
ctx, cancel := context.WithTimeout(ctx, timeout) // BUG -
defer cancel()
if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
runner.llama.Ping(ctx) != nil {
return true
}
return false
}
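// In short: with identical adapters and projectors, only an option other than
// NumGPU forces a reload. A hypothetical sketch using this file's types:
//
//	opts := *runner.Options
//	opts.NumGPU = 99     // ignored by the comparison above; the runner is reused
//	opts.NumBatch = 1024 // any other Runner field differing triggers a reload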
type ByDuration []*runnerRef
func (a ByDuration) Len() int { return len(a) }
func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByDuration) Less(i, j int) bool {
// uint64 to turn negative time (never unload) to largest
return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
}
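// The uint64 cast maps a negative duration ("never unload") above every
// non-negative one. A quick sanity check with this file's types:
//
//	rs := []*runnerRef{{sessionDuration: -1}, {sessionDuration: time.Minute}}
//	sort.Sort(ByDuration(rs))
//	// rs[0] now holds the time.Minute runner; the pinned one sorts last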
// TODO - future consideration to pick runners based on size
// type BySize []*runnerRef
// func (a BySize) Len() int { return len(a) }
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
// If the model cannot fit fully within the available GPU(s), nil is returned
func pickBestFitGPUs(req *LlmRequest, gpus gpu.GpuInfoList) gpu.GpuInfoList {
var estimatedVRAM uint64
for _, gl := range gpus.ByLibrary() {
var ok bool
sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
// TODO - potentially sort by performance capability, existing models loaded, etc.
// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
// First attempt to fit the model into a single GPU
for _, g := range sgl {
if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
return []gpu.GpuInfo{g}
}
}
// TODO future refinements
// - if multiple Libraries, see if any single GPU in any Library will fit
// - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUs
if ok, estimatedVRAM = llm.PredictServerFit(gl, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
return gl
}
}
return nil
}
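// For example, with 12 GB and 8 GB free on two GPUs in the same library, a
// model predicted at 10 GB lands on the 12 GB GPU alone, while one predicted
// at 16 GB falls through to the whole group (hypothetical sizes).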
// findRunnerToUnload finds a runner to unload to make room for a new model
func (s *Scheduler) findRunnerToUnload(req *LlmRequest) *runnerRef {
s.loadedMu.Lock()
runnerList := make([]*runnerRef, 0, len(s.loaded))
for _, r := range s.loaded {
runnerList = append(runnerList, r)
}
s.loadedMu.Unlock()
// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
// e.g., if we have multiple options, will one make room for the request?
sort.Sort(ByDuration(runnerList))
// First try to find a runner that's already idle
for _, runner := range runnerList {
runner.refMu.Lock()
rc := runner.refCount
runner.refMu.Unlock()
if rc == 0 {
slog.Debug("found an idle runner to unload")
return runner
}
}
// None appear idle, just wait for the one with the shortest duration
slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
return runnerList[0]
}
func (s *Scheduler) unloadAllRunners() {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
for model, runner := range s.loaded {
if runner.llama != nil {
slog.Debug("shutting down runner", "model", model)
runner.llama.Close()
}
}
}
package server
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"log/slog"
"os"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
os.Setenv("OLLAMA_DEBUG", "1")
lifecycle.InitLogging()
}
func TestInitScheduler(t *testing.T) {
ctx, done := context.WithCancel(context.Background())
defer done()
initialMax := loadedMax
s := InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
require.NotNil(t, s.loaded)
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
s = InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
require.NotNil(t, s.loaded)
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
s = InitScheduler(ctx)
require.Equal(t, 0, loadedMax)
require.NotNil(t, s.loaded)
}
func TestLoad(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
s := InitScheduler(ctx)
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: 2,
}
// Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return nil, fmt.Errorf("something failed to load model blah")
}
gpus := gpu.GpuInfoList{}
s.load(req, gpus)
require.Len(t, req.successCh, 0)
require.Len(t, req.errCh, 1)
require.Len(t, s.loaded, 0)
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")
server := &mockLlm{estimatedVRAM: 10}
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return server, nil
}
s.load(req, gpus)
select {
case err := <-req.errCh:
require.NoError(t, err)
case resp := <-req.successCh:
require.Equal(t, uint64(10), resp.estimatedVRAM)
require.Equal(t, uint(1), resp.refCount)
require.Len(t, s.loaded, 1)
}
req.model.ModelPath = "dummy_model_path"
server.waitResp = fmt.Errorf("wait failure")
s.load(req, gpus)
select {
case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure")
case resp := <-req.successCh:
t.Errorf("unexpected success %v", resp)
}
runner := s.loaded["dummy_model_path"]
require.NotNil(t, runner)
require.Equal(t, uint(0), runner.refCount)
require.Len(t, s.expiredCh, 1)
}
type bundle struct {
ctx context.Context //nolint:containedctx
ctxDone func()
srv *mockLlm
req *LlmRequest
}
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return scenario.srv, nil
}
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
assert.Nil(t, err)
defer f.Close()
gguf := llm.NewGGUFV3(binary.LittleEndian)
err = gguf.Encode(f, llm.KV{
"general.architecture": "llama",
"general.name": "name",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
})
assert.Nil(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
ggml, err := llm.LoadModel(model.ModelPath)
require.NoError(t, err)
scenario.req = &LlmRequest{
ctx: scenario.ctx,
model: model,
ggml: ggml,
sessionDuration: 5 * time.Millisecond,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
return scenario
}
func TestRequests(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.req.ggml = scenario1a.req.ggml
scenario1b.req.sessionDuration = 0
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
scenario2a.req.model = scenario1a.req.model
scenario2a.req.ggml = scenario1a.req.ggml
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Same runner as first request due to not needing a reload
s.newServerFn = scenario1b.newServer
slog.Info("scenario1b")
s.pendingReqCh <- scenario1b.req
select {
case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Trigger a reload
s.newServerFn = scenario2a.newServer
scenario2a.req.model.AdapterPaths = []string{"new"}
slog.Info("scenario2a")
s.pendingReqCh <- scenario2a.req
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
scenario1a.ctxDone()
scenario1b.ctxDone()
select {
case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario2a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
loadedMax = 1
s.newServerFn = scenario3a.newServer
slog.Info("scenario3a")
s.pendingReqCh <- scenario3a.req
// finish prior request, so new model can load
time.Sleep(1 * time.Millisecond)
scenario2a.ctxDone()
select {
case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 1)
loadedMax = 0
s.newServerFn = scenario3b.newServer
slog.Info("scenario3b")
s.pendingReqCh <- scenario3b.req
select {
case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 2)
// Try to load a model that won't fit
s.newServerFn = scenario3c.newServer
slog.Info("scenario3c")
require.Len(t, s.loaded, 2)
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3c.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
require.Len(t, s.loaded, 1)
scenario3b.ctxDone()
select {
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3c.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 1)
}
func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = 0
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = 0
maxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Len(t, successCh1b, 0)
require.Len(t, errCh1b, 1)
err := <-errCh1b
require.Contains(t, err.Error(), "server busy")
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
scenario1a.ctxDone()
require.Len(t, s.loaded, 1)
scenario1c.req.model.ModelPath = "bad path"
slog.Info("scenario1c")
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, successCh1c, 0)
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
scenario1b.ctxDone()
time.Sleep(5 * time.Millisecond)
require.Len(t, s.loaded, 0)
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count
func TestPrematureExpired(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
require.Len(t, s.loaded, 1)
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure it's safe
case <-ctx.Done():
t.Errorf("timeout")
}
time.Sleep(scenario1a.req.sessionDuration)
scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
time.Sleep(10 * time.Millisecond)
require.Len(t, s.finishedReqCh, 0)
require.Len(t, s.loaded, 0)
// also shouldn't happen in real life
s.finishedReqCh <- scenario1a.req
time.Sleep(5 * time.Millisecond)
}
func TestUseLoadedRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
req := &LlmRequest{
ctx: ctx,
successCh: make(chan *runnerRef, 1),
sessionDuration: 2,
}
finished := make(chan *LlmRequest)
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1, sessionDuration: 1}
req.useLoadedRunner(r1, finished)
require.Equal(t, uint(1), r1.refCount)
require.Equal(t, time.Duration(2), r1.sessionDuration)
select {
case success := <-req.successCh:
require.Equal(t, r1, success)
case <-ctx.Done():
t.Errorf("timeout")
}
done()
fin := <-finished
require.Equal(t, req, fin)
}
func TestUpdateFreeSpace(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
gpus := gpu.GpuInfoList{
{
Library: "a",
ID: "1",
},
{
Library: "a",
ID: "2",
},
}
gpus[0].TotalMemory = 1000
gpus[0].FreeMemory = 900
gpus[1].TotalMemory = 2000
gpus[1].FreeMemory = 1900
llm1 := &mockLlm{estimatedVRAM: 100}
llm2 := &mockLlm{estimatedVRAM: 200}
r1 := &runnerRef{llama: llm1, gpus: gpus}
r2 := &runnerRef{llama: llm2, gpus: gpus}
s := InitScheduler(ctx)
s.loaded["a"] = r1
s.loaded["b"] = r2
s.updateFreeSpace(gpus)
require.Equal(t, uint64(850), gpus[0].FreeMemory)
require.Equal(t, uint64(1850), gpus[1].FreeMemory)
}
func TestFindRunnerToUnload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
req := &LlmRequest{ctx: ctx}
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
r2 := &runnerRef{sessionDuration: 2}
s := InitScheduler(ctx)
s.loaded["a"] = r1
s.loaded["b"] = r2
resp := s.findRunnerToUnload(req)
require.Equal(t, r2, resp)
r2.refCount = 1
resp = s.findRunnerToUnload(req)
require.Equal(t, r1, resp)
}
func TestNeedsReload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
llm := &mockLlm{}
runner := &runnerRef{
adapters: []string{"adapter1"},
projectors: []string{"projector1"},
Options: &api.Options{},
llama: llm,
}
req := &LlmRequest{
model: &Model{
AdapterPaths: []string{"adapter2"},
ProjectorPaths: []string{"projector2"},
},
opts: api.Options{},
}
resp := runner.needsReload(ctx, req)
require.True(t, resp)
req.model.AdapterPaths = runner.adapters
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.model.ProjectorPaths = runner.projectors
runner.loading = true
req.opts.NumBatch = 1234
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.opts.NumBatch = runner.Options.NumBatch
llm.pingResp = fmt.Errorf("foo")
resp = runner.needsReload(ctx, req)
require.True(t, resp)
llm.pingResp = nil
resp = runner.needsReload(ctx, req)
require.False(t, resp)
req.opts.NumGPU = 99
resp = runner.needsReload(ctx, req)
require.False(t, resp)
}
func TestUnloadAllRunners(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
llm1 := &mockLlm{}
llm2 := &mockLlm{}
s := InitScheduler(ctx)
s.unloadAllRunners()
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{llama: llm2}
s.loaded["a"] = r1
s.loaded["b"] = r2
s.unloadAllRunners()
require.True(t, llm1.closeCalled)
require.True(t, llm2.closeCalled)
}
func TestUnload(t *testing.T) {
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{adapters: []string{"A"}}
r1.unload()
require.True(t, llm1.closeCalled)
r2.unload()
require.Nil(t, r2.adapters)
}
type mockLlm struct {
pingResp error
waitResp error
completionResp error
embeddingResp []float64
embeddingRespErr error
tokenizeResp []int
tokenizeRespErr error
detokenizeResp string
detokenizeRespErr error
closeResp error
closeCalled bool
estimatedVRAM uint64
}
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr
}
func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
return s.detokenizeResp, s.detokenizeRespErr
}
func (s *mockLlm) Close() error {
s.closeCalled = true
return s.closeResp
}
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }