Commit ac7a842e authored by Michael Yang
Browse files

fix model reloading

ensure runtime model changes (template, system prompt, messages,
options) are captured on model updates without needing to reload the
server
parent 2c3fe1fd
...@@ -679,7 +679,7 @@ type CompletionRequest struct { ...@@ -679,7 +679,7 @@ type CompletionRequest struct {
Prompt string Prompt string
Format string Format string
Images []ImageData Images []ImageData
Options api.Options Options *api.Options
} }
type CompletionResponse struct { type CompletionResponse struct {
......
...@@ -69,23 +69,25 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options ...@@ -69,23 +69,25 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil return opts, nil
} }
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) { // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" { if name == "" {
return nil, fmt.Errorf("model %w", errRequired) return nil, nil, nil, fmt.Errorf("model %w", errRequired)
} }
model, err := GetModel(name) model, err := GetModel(name)
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
if err := model.CheckCapabilities(caps...); err != nil { if err := model.CheckCapabilities(caps...); err != nil {
return nil, fmt.Errorf("%s %w", name, err) return nil, nil, nil, fmt.Errorf("%s %w", name, err)
} }
opts, err := modelOptions(model, requestOpts) opts, err := modelOptions(model, requestOpts)
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive) runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
...@@ -93,10 +95,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil ...@@ -93,10 +95,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
select { select {
case runner = <-runnerCh: case runner = <-runnerCh:
case err = <-errCh: case err = <-errCh:
return nil, err return nil, nil, nil, err
} }
return runner, nil return runner.llama, model, &opts, nil
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
...@@ -118,7 +120,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -118,7 +120,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return return
...@@ -147,8 +149,8 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -147,8 +149,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var msgs []api.Message var msgs []api.Message
if req.System != "" { if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System}) msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if r.model.System != "" { } else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: r.model.System}) msgs = append(msgs, api.Message{Role: "system", Content: m.System})
} }
for _, i := range images { for _, i := range images {
...@@ -157,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -157,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt}) msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
tmpl := r.model.Template tmpl := m.Template
if req.Template != "" { if req.Template != "" {
tmpl, err = template.Parse(req.Template) tmpl, err = template.Parse(req.Template)
if err != nil { if err != nil {
...@@ -168,7 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -168,7 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var b bytes.Buffer var b bytes.Buffer
if req.Context != nil { if req.Context != nil {
s, err := r.llama.Detokenize(c.Request.Context(), req.Context) s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
...@@ -190,11 +192,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -190,11 +192,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: *r.Options, Options: opts,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
ch <- api.GenerateResponse{ ch <- api.GenerateResponse{
Model: req.Model, Model: req.Model,
...@@ -254,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { ...@@ -254,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil { if err != nil {
handleScheduleError(c, req.Model, err) handleScheduleError(c, req.Model, err)
return return
...@@ -266,7 +268,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { ...@@ -266,7 +268,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt) embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
...@@ -1130,7 +1132,7 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1130,7 +1132,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return return
...@@ -1150,7 +1152,7 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1150,7 +1152,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
...@@ -1161,11 +1163,11 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1161,11 +1163,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: *r.Options, Options: opts,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
ch <- api.ChatResponse{ ch <- api.ChatResponse{
Model: req.Model, Model: req.Model,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment