server.go 4.24 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
package lifecycle

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"os"
	"os/exec"
	"path/filepath"
	"time"

14
	"github.com/ollama/ollama/api"
15
16
17
)

// getCLIFullPath resolves command to a path that exists on disk.
//
// Candidates are probed in this order, and the first one that can be stat'ed
// is returned:
//  1. next to the running executable
//  2. in a "bin" directory next to the running executable
//  3. wherever exec.LookPath finds it
//  4. in the current working directory
//
// If none of them exists, command is returned unchanged.
func getCLIFullPath(command string) string {
	// exists reports whether p can be stat'ed.
	exists := func(p string) bool {
		_, statErr := os.Stat(p)
		return statErr == nil
	}

	if exe, exeErr := os.Executable(); exeErr == nil {
		// Check both the same location as the tray app, as well as ./bin
		base := filepath.Dir(exe)
		for _, candidate := range []string{
			filepath.Join(base, command),
			filepath.Join(base, "bin", command),
		} {
			if exists(candidate) {
				return candidate
			}
		}
	}

	if found, lookErr := exec.LookPath(command); lookErr == nil && exists(found) {
		return found
	}

	if wd, wdErr := os.Getwd(); wdErr == nil {
		if candidate := filepath.Join(wd, command); exists(candidate) {
			return candidate
		}
	}

	return command
}

52
func start(ctx context.Context, command string) (*exec.Cmd, error) {
53
54
55
	cmd := getCmd(ctx, getCLIFullPath(command))
	stdout, err := cmd.StdoutPipe()
	if err != nil {
56
		return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
57
58
59
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
60
		return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
61
62
	}

63
	rotateLogs(ServerLogFile)
Michael Yang's avatar
lint  
Michael Yang committed
64
	logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o755)
65
	if err != nil {
66
		return nil, fmt.Errorf("failed to create server log: %w", err)
67
	}
68
69
70
71
72
73
74
75
76
77
78
79
80

	logDir := filepath.Dir(ServerLogFile)
	_, err = os.Stat(logDir)
	if err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
		}

		if err := os.MkdirAll(logDir, 0o755); err != nil {
			return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
		}
	}

81
82
83
84
85
86
87
88
89
	go func() {
		defer logFile.Close()
		io.Copy(logFile, stdout) //nolint:errcheck
	}()
	go func() {
		defer logFile.Close()
		io.Copy(logFile, stderr) //nolint:errcheck
	}()

90
91
92
	// Re-wire context done behavior to attempt a graceful shutdown of the server
	cmd.Cancel = func() error {
		if cmd.Process != nil {
93
94
95
96
97
98
			err := terminate(cmd)
			if err != nil {
				slog.Warn("error trying to gracefully terminate server", "err", err)
				return cmd.Process.Kill()
			}

99
100
			tick := time.NewTicker(10 * time.Millisecond)
			defer tick.Stop()
101

102
103
104
			for {
				select {
				case <-tick.C:
105
106
107
108
109
110
111
					exited, err := isProcessExited(cmd.Process.Pid)
					if err != nil {
						return err
					}

					if exited {
						return nil
112
113
114
					}
				case <-time.After(5 * time.Second):
					slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
115
					return cmd.Process.Kill()
116
117
118
119
120
121
				}
			}
		}
		return nil
	}

122
123
	// run the command and wait for it to finish
	if err := cmd.Start(); err != nil {
124
		return nil, fmt.Errorf("failed to start server %w", err)
125
126
127
128
129
130
	}
	if cmd.Process != nil {
		slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
	}
	slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))

131
132
133
134
135
136
	return cmd, nil
}

func SpawnServer(ctx context.Context, command string) (chan int, error) {
	done := make(chan int)

137
138
139
140
	go func() {
		// Keep the server running unless we're shuttind down the app
		crashCount := 0
		for {
141
142
143
144
145
146
147
148
149
			slog.Info("starting server...")
			cmd, err := start(ctx, command)
			if err != nil {
				crashCount++
				slog.Error(fmt.Sprintf("failed to start server %s", err))
				time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
				continue
			}

150
151
152
153
154
155
156
157
			cmd.Wait() //nolint:errcheck
			var code int
			if cmd.ProcessState != nil {
				code = cmd.ProcessState.ExitCode()
			}

			select {
			case <-ctx.Done():
158
				slog.Info(fmt.Sprintf("server shutdown with exit code %d", code))
159
160
161
162
163
				done <- code
				return
			default:
				crashCount++
				slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
164
165
				time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
				break
166
167
168
			}
		}
	}()
169

170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
	return done, nil
}

// IsServerRunning reports whether a server answers a heartbeat request.
//
// It builds an API client from the environment and calls Heartbeat with ctx.
// If either step fails it logs "unable to connect to server" and returns
// false; the heartbeat error itself is logged at debug level.
func IsServerRunning(ctx context.Context) bool {
	// unreachable logs the common failure message and yields the negative result.
	unreachable := func() bool {
		slog.Info("unable to connect to server")
		return false
	}

	client, clientErr := api.ClientFromEnvironment()
	if clientErr != nil {
		return unreachable()
	}

	if beatErr := client.Heartbeat(ctx); beatErr != nil {
		slog.Debug(fmt.Sprintf("heartbeat from server: %s", beatErr))
		return unreachable()
	}

	return true
}