Unverified Commit c863c6a9 authored by Daniel Hiltgen's avatar Daniel Hiltgen Committed by GitHub
Browse files

Merge pull request #3218 from dhiltgen/subprocess

Switch back to subprocessing for llama.cpp
parents 3b6a9154 1f11b525
......@@ -56,12 +56,13 @@ func init() {
var loaded struct {
mu sync.Mutex
runner llm.LLM
llama *llm.LlamaServer
expireAt time.Time
expireTimer *time.Timer
*Model
model string
adapters []string
projectors []string
*api.Options
}
......@@ -69,21 +70,28 @@ var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.Duration) error {
needLoad := loaded.runner == nil || // is there a model loaded?
loaded.ModelPath != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()
needLoad := loaded.llama == nil || // is there a model loaded?
loaded.model != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
loaded.llama.Ping(ctx) != nil
if needLoad {
if loaded.runner != nil {
if loaded.llama != nil {
slog.Info("changing loaded model")
loaded.runner.Close()
loaded.runner = nil
loaded.Model = nil
loaded.llama.Close()
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
}
llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
......@@ -95,28 +103,26 @@ func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.
return err
}
loaded.Model = model
loaded.runner = llmRunner
loaded.model = model.ModelPath
loaded.adapters = model.AdapterPaths
loaded.projectors = model.ProjectorPaths
loaded.llama = llama
loaded.Options = opts
}
loaded.expireAt = time.Now().Add(sessionDuration)
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
if time.Now().Before(loaded.expireAt) {
return
if loaded.llama != nil {
loaded.llama.Close()
}
if loaded.runner != nil {
loaded.runner.Close()
}
loaded.runner = nil
loaded.Model = nil
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
})
}
......@@ -265,7 +271,7 @@ func GenerateHandler(c *gin.Context) {
sb.Reset()
if req.Context != nil {
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
......@@ -286,9 +292,8 @@ func GenerateHandler(c *gin.Context) {
go func() {
defer close(ch)
fn := func(r llm.PredictResult) {
fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
// Build up the full response
......@@ -322,7 +327,7 @@ func GenerateHandler(c *gin.Context) {
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := loaded.runner.Encode(c.Request.Context(), p)
tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
......@@ -344,13 +349,13 @@ func GenerateHandler(c *gin.Context) {
}
// Start prediction
predictReq := llm.PredictOpts{
req := llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
......@@ -471,7 +476,7 @@ func EmbeddingsHandler(c *gin.Context) {
return
}
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
......@@ -1123,8 +1128,8 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
if loaded.runner != nil {
loaded.runner.Close()
if loaded.llama != nil {
loaded.llama.Close()
}
gpu.Cleanup()
os.Exit(0)
......@@ -1196,7 +1201,7 @@ func streamResponse(c *gin.Context, ch chan any) {
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return loaded.runner.Encode(ctx, s)
return loaded.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
......@@ -1326,9 +1331,8 @@ func ChatHandler(c *gin.Context) {
go func() {
defer close(ch)
fn := func(r llm.PredictResult) {
fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
......@@ -1352,14 +1356,12 @@ func ChatHandler(c *gin.Context) {
ch <- resp
}
// Start prediction
predictReq := llm.PredictOpts{
if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
}, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
......
......@@ -17,7 +17,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/version"
)
......@@ -211,7 +210,7 @@ func Test_Routes(t *testing.T) {
},
}
s := Server{}
s := &Server{}
router := s.GenerateRoutes()
httpSrv := httptest.NewServer(router)
......@@ -242,27 +241,3 @@ func Test_Routes(t *testing.T) {
}
}
type MockLLM struct {
encoding []int
}
func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
return nil
}
func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
return llm.encoding, nil
}
func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
return "", nil
}
func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
return []float64{}, nil
}
func (llm *MockLLM) Close() {
// do nothing
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment