Unverified commit c863c6a9, authored by Daniel Hiltgen, committed by GitHub

Merge pull request #3218 from dhiltgen/subprocess

Switch back to subprocessing for llama.cpp
Parents: 3b6a9154 1f11b525
@@ -56,12 +56,13 @@ func init() {
 var loaded struct {
 	mu sync.Mutex

-	runner   llm.LLM
-	expireAt time.Time
+	llama *llm.LlamaServer

 	expireTimer *time.Timer

-	*Model
+	model      string
+	adapters   []string
+	projectors []string

 	*api.Options
 }
@@ -69,21 +70,28 @@ var defaultSessionDuration = 5 * time.Minute
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
 func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.Duration) error {

-	needLoad := loaded.runner == nil || // is there a model loaded?
-		loaded.ModelPath != model.ModelPath || // has the base model changed?
-		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
-		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
+	ctx, cancel := context.WithTimeout(c, 10*time.Second)
+	defer cancel()
+
+	needLoad := loaded.llama == nil || // is there a model loaded?
+		loaded.model != model.ModelPath || // has the base model changed?
+		!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
+		!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the projectors changed?
+		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
+		loaded.llama.Ping(ctx) != nil

 	if needLoad {
-		if loaded.runner != nil {
+		if loaded.llama != nil {
 			slog.Info("changing loaded model")
-			loaded.runner.Close()
-			loaded.runner = nil
-			loaded.Model = nil
+			loaded.llama.Close()
+			loaded.llama = nil
+			loaded.model = ""
+			loaded.adapters = nil
+			loaded.projectors = nil
 			loaded.Options = nil
 		}

-		llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
+		llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
 		if err != nil {
 			// some older models are not compatible with newer versions of llama.cpp
 			// show a generalized compatibility error until there is a better way to
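The rebuilt needLoad check adds a liveness probe: even when the same model and options are requested, the model is reloaded if the llama.cpp subprocess stops answering Ping within the 10-second context deadline (Go's || short-circuits, so Ping is only reached when loaded.llama is non-nil). This diff does not show Ping's body; below is a minimal, hypothetical sketch of what such a health check could look like, assuming a /health endpoint and a port field, neither of which is taken from this commit.

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

// LlamaServer is a stand-in for llm.LlamaServer; the /health endpoint and
// the port field are assumptions for illustration only.
type LlamaServer struct {
	port int
}

// Ping reports whether the subprocess's HTTP server is still answering.
// The surrounding load() bounds it with a 10-second context deadline.
func (s *LlamaServer) Ping(ctx context.Context) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet,
		fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("server not responding: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected health status: %d", resp.StatusCode)
	}
	return nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	err := (&LlamaServer{port: 8080}).Ping(ctx)
	fmt.Println("ping:", err)
}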
@@ -95,28 +103,26 @@ func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.Duration) error {
 			return err
 		}

-		loaded.Model = model
-		loaded.runner = llmRunner
+		loaded.model = model.ModelPath
+		loaded.adapters = model.AdapterPaths
+		loaded.projectors = model.ProjectorPaths
+		loaded.llama = llama
 		loaded.Options = opts
 	}

-	loaded.expireAt = time.Now().Add(sessionDuration)
-
 	if loaded.expireTimer == nil {
 		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
 			loaded.mu.Lock()
 			defer loaded.mu.Unlock()

-			if time.Now().Before(loaded.expireAt) {
-				return
-			}
-
-			if loaded.runner != nil {
-				loaded.runner.Close()
+			if loaded.llama != nil {
+				loaded.llama.Close()
 			}

-			loaded.runner = nil
-			loaded.Model = nil
+			loaded.llama = nil
+			loaded.model = ""
+			loaded.adapters = nil
+			loaded.projectors = nil
 			loaded.Options = nil
 		})
 	}
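The expiry bookkeeping is simplified alongside the subprocess change: the separate expireAt field is dropped everywhere, and the single timer created by time.AfterFunc on first load is slid forward with Reset on each use. A self-contained sketch of that idle-expiry pattern, with a plain string standing in for the loaded model:

package main

import (
	"fmt"
	"sync"
	"time"
)

// cache mirrors the shape of `loaded` above: one mutex-guarded slot plus a
// single timer that is created once and then rescheduled with Reset.
var cache struct {
	mu          sync.Mutex
	value       string
	expireTimer *time.Timer
}

func touch(ttl time.Duration) {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	cache.value = "loaded model"
	if cache.expireTimer == nil {
		cache.expireTimer = time.AfterFunc(ttl, func() {
			cache.mu.Lock()
			defer cache.mu.Unlock()
			cache.value = "" // evict on idle timeout
			fmt.Println("expired")
		})
		return
	}
	cache.expireTimer.Reset(ttl) // slide the deadline forward
}

func main() {
	touch(100 * time.Millisecond)
	time.Sleep(50 * time.Millisecond)
	touch(100 * time.Millisecond) // recent use postpones eviction
	time.Sleep(200 * time.Millisecond)
}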
@@ -265,7 +271,7 @@ func GenerateHandler(c *gin.Context) {
 		sb.Reset()

 		if req.Context != nil {
-			prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
+			prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
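Decode and Encode are renamed to the more conventional Detokenize and Tokenize. Here req.Context carries the token IDs of the prior conversation state, which are detokenized back into text before the new prompt is built. A round-trip sketch, where a byte-level mapping stands in for a real tokenizer (real tokenizers map text to model vocabulary IDs, not bytes):

package main

import "fmt"

// tokenize and detokenize are toy stand-ins for the server's Tokenize and
// Detokenize calls; the byte mapping is illustrative only.
func tokenize(s string) []int {
	ids := make([]int, len(s))
	for i := 0; i < len(s); i++ {
		ids[i] = int(s[i])
	}
	return ids
}

func detokenize(ids []int) string {
	b := make([]byte, len(ids))
	for i, id := range ids {
		b[i] = byte(id)
	}
	return string(b)
}

func main() {
	ctxTokens := tokenize("previous turn") // what req.Context would carry
	prev := detokenize(ctxTokens)          // recover the prior text
	fmt.Println(prev + " + new prompt")
}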
@@ -286,9 +292,8 @@ func GenerateHandler(c *gin.Context) {
 		go func() {
 			defer close(ch)

-			fn := func(r llm.PredictResult) {
+			fn := func(r llm.CompletionResponse) {
 				// Update model expiration
-				loaded.expireAt = time.Now().Add(sessionDuration)
 				loaded.expireTimer.Reset(sessionDuration)

 				// Build up the full response
@@ -322,7 +327,7 @@ func GenerateHandler(c *gin.Context) {
 			}

 			// TODO (jmorganca): encode() should not strip special tokens
-			tokens, err := loaded.runner.Encode(c.Request.Context(), p)
+			tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
 			if err != nil {
 				ch <- gin.H{"error": err.Error()}
 				return
@@ -344,13 +349,13 @@ func GenerateHandler(c *gin.Context) {
 			}

 			// Start prediction
-			predictReq := llm.PredictOpts{
+			req := llm.CompletionRequest{
 				Prompt:  prompt,
 				Format:  req.Format,
 				Images:  images,
 				Options: opts,
 			}
-			if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+			if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
 				ch <- gin.H{"error": err.Error()}
 			}
 		}()
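Predict/PredictOpts/PredictResult become Completion/CompletionRequest/CompletionResponse, but the streaming shape is unchanged: the server invokes a callback once per generated chunk, and the handler forwards each chunk onto a channel that the response writer drains. A runnable sketch of that callback-to-channel pattern; CompletionResponse's real fields are not shown in this diff, so Content and Done here are assumptions:

package main

import "fmt"

// CompletionResponse is a stand-in; its fields are illustrative.
type CompletionResponse struct {
	Content string
	Done    bool
}

// completion stands in for (*llm.LlamaServer).Completion: it drives the
// callback once per chunk and returns when generation finishes.
func completion(prompt string, fn func(CompletionResponse)) error {
	for _, w := range []string{"hello", " ", "world"} {
		fn(CompletionResponse{Content: w})
	}
	fn(CompletionResponse{Done: true})
	return nil
}

func main() {
	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(r CompletionResponse) { ch <- r } // forward each chunk
		if err := completion("hi", fn); err != nil {
			ch <- map[string]any{"error": err.Error()}
		}
	}()
	for msg := range ch { // what streamResponse would drain
		fmt.Printf("%+v\n", msg)
	}
}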
@@ -471,7 +476,7 @@ func EmbeddingsHandler(c *gin.Context) {
 		return
 	}

-	embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -1123,8 +1128,8 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		if loaded.runner != nil {
-			loaded.runner.Close()
+		if loaded.llama != nil {
+			loaded.llama.Close()
 		}
 		gpu.Cleanup()
 		os.Exit(0)
@@ -1196,7 +1201,7 @@ func streamResponse(c *gin.Context, ch chan any) {
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
 func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
-		return loaded.runner.Encode(ctx, s)
+		return loaded.llama.Tokenize(ctx, s)
 	}

 	prompt, err := ChatPrompt(template, messages, numCtx, encode)
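chatPrompt keeps the tokenizer behind a plain closure, so prompt construction depends only on a func(string) ([]int, error) rather than on *llm.LlamaServer; this seam is also what let the old tests substitute a mock. A sketch of that dependency-injection shape, with a word-count stand-in for Tokenize (buildPrompt is illustrative, not ollama's ChatPrompt):

package main

import (
	"fmt"
	"strings"
)

// buildPrompt only sees the injected encode function, never the server type.
func buildPrompt(numCtx int, encode func(string) ([]int, error)) (string, error) {
	prompt := "system: be brief\nuser: hi"
	tokens, err := encode(prompt)
	if err != nil {
		return "", err
	}
	if len(tokens) > numCtx {
		return "", fmt.Errorf("prompt is %d tokens, context window is %d", len(tokens), numCtx)
	}
	return prompt, nil
}

func main() {
	// A trivial tokenizer: one token per whitespace-separated word.
	encode := func(s string) ([]int, error) {
		return make([]int, len(strings.Fields(s))), nil
	}
	fmt.Println(buildPrompt(2048, encode))
}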
@@ -1326,9 +1331,8 @@ func ChatHandler(c *gin.Context) {
 		go func() {
 			defer close(ch)

-			fn := func(r llm.PredictResult) {
+			fn := func(r llm.CompletionResponse) {
 				// Update model expiration
-				loaded.expireAt = time.Now().Add(sessionDuration)
 				loaded.expireTimer.Reset(sessionDuration)

 				resp := api.ChatResponse{
@@ -1352,14 +1356,12 @@ func ChatHandler(c *gin.Context) {
 				ch <- resp
 			}

-			// Start prediction
-			predictReq := llm.PredictOpts{
-				Prompt:  prompt,
-				Format:  req.Format,
-				Images:  images,
-				Options: opts,
-			}
-			if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+			if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+				Prompt:  prompt,
+				Format:  req.Format,
+				Images:  images,
+				Options: opts,
+			}, fn); err != nil {
 				ch <- gin.H{"error": err.Error()}
 			}
 		}()
@@ -17,7 +17,6 @@ import (
 	"github.com/stretchr/testify/assert"

 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/version"
 )
@@ -211,7 +210,7 @@ func Test_Routes(t *testing.T) {
 		},
 	}

-	s := Server{}
+	s := &Server{}
 	router := s.GenerateRoutes()

 	httpSrv := httptest.NewServer(router)
@@ -242,27 +241,3 @@ func Test_Routes(t *testing.T) {
 		}
 	}
 }
-
-type MockLLM struct {
-	encoding []int
-}
-
-func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
-	return nil
-}
-
-func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
-	return llm.encoding, nil
-}
-
-func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
-	return "", nil
-}
-
-func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
-	return []float64{}, nil
-}
-
-func (llm *MockLLM) Close() {
-	// do nothing
-}
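The mock disappears because the seam it filled is gone: the routes now call a concrete *llm.LlamaServer subprocess client rather than an llm.LLM interface, so there is nothing left to substitute in the tests. For reference, the interface shape the mock satisfied, reconstructed from its method set; the old llm package's actual declaration is an assumption:

package llmshape

import "context"

// Placeholder types standing in for the old llm.PredictOpts and
// llm.PredictResult; their real fields are not shown in this diff.
type (
	PredictOpts   struct{}
	PredictResult struct{}
)

// LLM is reconstructed from MockLLM's methods (Predict, Encode, Decode,
// Embedding, Close), matching the old call sites in routes.go.
type LLM interface {
	Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error
	Encode(ctx context.Context, prompt string) ([]int, error)
	Decode(ctx context.Context, tokens []int) (string, error)
	Embedding(ctx context.Context, input string) ([]float64, error)
	Close()
}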