Unverified Commit abed273d authored by Patrick Devine's avatar Patrick Devine Committed by GitHub
Browse files

add "stop" command (#6739)

parent 03439262
...@@ -346,6 +346,39 @@ func (w *progressWriter) Write(p []byte) (n int, err error) { ...@@ -346,6 +346,39 @@ func (w *progressWriter) Write(p []byte) (n int, err error) {
return len(p), nil return len(p), nil
} }
func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
req := &api.GenerateRequest{
Model: opts.Model,
KeepAlive: opts.KeepAlive,
}
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
}
func StopHandler(cmd *cobra.Command, args []string) error {
opts := &runOptions{
Model: args[0],
KeepAlive: &api.Duration{Duration: 0},
}
if err := loadOrUnloadModel(cmd, opts); err != nil {
if strings.Contains(err.Error(), "not found") {
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
}
}
return nil
}
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true interactive := true
...@@ -424,7 +457,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { ...@@ -424,7 +457,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.ParentModel = info.Details.ParentModel opts.ParentModel = info.Details.ParentModel
if interactive { if interactive {
if err := loadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
return err return err
} }
...@@ -615,7 +648,15 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error { ...@@ -615,7 +648,15 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100) cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent)) procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
} }
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
var until string
delta := time.Since(m.ExpiresAt)
if delta > 0 {
until = "Stopping..."
} else {
until = format.HumanTime(m.ExpiresAt, "Never")
}
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
} }
} }
...@@ -1294,6 +1335,15 @@ func NewCLI() *cobra.Command { ...@@ -1294,6 +1335,15 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)") runCmd.Flags().String("format", "", "Response format (e.g. json)")
stopCmd := &cobra.Command{
Use: "stop MODEL",
Short: "Stop a running model",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: StopHandler,
}
serveCmd := &cobra.Command{ serveCmd := &cobra.Command{
Use: "serve", Use: "serve",
Aliases: []string{"start"}, Aliases: []string{"start"},
...@@ -1361,6 +1411,7 @@ func NewCLI() *cobra.Command { ...@@ -1361,6 +1411,7 @@ func NewCLI() *cobra.Command {
createCmd, createCmd,
showCmd, showCmd,
runCmd, runCmd,
stopCmd,
pullCmd, pullCmd,
pushCmd, pushCmd,
listCmd, listCmd,
...@@ -1400,6 +1451,7 @@ func NewCLI() *cobra.Command { ...@@ -1400,6 +1451,7 @@ func NewCLI() *cobra.Command {
createCmd, createCmd,
showCmd, showCmd,
runCmd, runCmd,
stopCmd,
pullCmd, pullCmd,
pushCmd, pushCmd,
listCmd, listCmd,
......
...@@ -18,7 +18,6 @@ import ( ...@@ -18,7 +18,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
) )
...@@ -31,26 +30,6 @@ const ( ...@@ -31,26 +30,6 @@ const (
MultilineSystem MultilineSystem
) )
func loadModel(cmd *cobra.Command, opts *runOptions) error {
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
chatReq := &api.ChatRequest{
Model: opts.Model,
KeepAlive: opts.KeepAlive,
}
return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
}
func generateInteractive(cmd *cobra.Command, opts runOptions) error { func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usage := func() { usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
...@@ -217,7 +196,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { ...@@ -217,7 +196,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Model = args[1] opts.Model = args[1]
opts.Messages = []api.Message{} opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model) fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
return err return err
} }
continue continue
......
...@@ -117,6 +117,32 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -117,6 +117,32 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
model, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
s.sched.expireRunner(model)
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: "",
Done: true,
DoneReason: "unload",
})
return
}
if req.Format != "" && req.Format != "json" { if req.Format != "" && req.Format != "json" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
return return
...@@ -1322,6 +1348,32 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1322,6 +1348,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
// expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
model, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
s.sched.expireRunner(model)
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
Done: true,
DoneReason: "unload",
})
return
}
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
if len(req.Tools) > 0 { if len(req.Tools) > 0 {
caps = append(caps, CapabilityTools) caps = append(caps, CapabilityTools)
......
...@@ -360,7 +360,6 @@ func (s *Scheduler) processCompleted(ctx context.Context) { ...@@ -360,7 +360,6 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("runner expired event received", "modelPath", runner.modelPath) slog.Debug("runner expired event received", "modelPath", runner.modelPath)
runner.refMu.Lock() runner.refMu.Lock()
if runner.refCount > 0 { if runner.refCount > 0 {
// Shouldn't happen, but safeguard to ensure no leaked runners
slog.Debug("expired event with positive ref count, retrying", "modelPath", runner.modelPath, "refCount", runner.refCount) slog.Debug("expired event with positive ref count, retrying", "modelPath", runner.modelPath, "refCount", runner.refCount)
go func(runner *runnerRef) { go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes // We can't unload yet, but want to as soon as the current request completes
...@@ -802,6 +801,25 @@ func (s *Scheduler) unloadAllRunners() { ...@@ -802,6 +801,25 @@ func (s *Scheduler) unloadAllRunners() {
} }
} }
func (s *Scheduler) expireRunner(model *Model) {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
runner, ok := s.loaded[model.ModelPath]
if ok {
runner.refMu.Lock()
runner.expiresAt = time.Now()
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
runner.sessionDuration = 0
if runner.refCount <= 0 {
s.expiredCh <- runner
}
runner.refMu.Unlock()
}
}
// If other runners are loaded, make sure the pending request will fit in system memory // If other runners are loaded, make sure the pending request will fit in system memory
// If not, pick a runner to unload, else return nil and the request can be loaded // If not, pick a runner to unload, else return nil and the request can be loaded
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) *runnerRef { func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) *runnerRef {
......
...@@ -406,6 +406,52 @@ func TestGetRunner(t *testing.T) { ...@@ -406,6 +406,52 @@ func TestGetRunner(t *testing.T) {
b.ctxDone() b.ctxDone()
} }
func TestExpireRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer done()
s := InitScheduler(ctx)
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: &api.Duration{Duration: 2 * time.Minute},
}
var ggml *llm.GGML
gpus := gpu.GpuInfoList{}
server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return server, nil
}
s.load(req, ggml, gpus, 0)
select {
case err := <-req.errCh:
if err != nil {
t.Fatalf("expected no errors when loading, got '%s'", err.Error())
}
case resp := <-req.successCh:
s.loadedMu.Lock()
if resp.refCount != uint(1) || len(s.loaded) != 1 {
t.Fatalf("expected a model to be loaded")
}
s.loadedMu.Unlock()
}
s.expireRunner(&Model{ModelPath: "foo"})
s.finishedReqCh <- req
s.processCompleted(ctx)
s.loadedMu.Lock()
if len(s.loaded) != 0 {
t.Fatalf("expected model to be unloaded")
}
s.loadedMu.Unlock()
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count // TODO - add one scenario that triggers the bogus finished event with positive ref count
func TestPrematureExpired(t *testing.T) { func TestPrematureExpired(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment