// process.go: helpers for monitoring a restored process, forwarding its
// output and signals, and exec'ing a cold-start command.

package restore

import (
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"os/signal"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/sirupsen/logrus"
)

// MonitorProcess monitors the restored process and returns its exit code.
// It blocks until the process exits. It does not forward stdout/stderr;
// for output forwarding, use ForwardProcessOutput instead.
//
// It delegates to waitForProcess. When pid is a child of this process, the
// child is reaped with wait4(2) and its real exit status is returned (128+N
// if it was killed by signal N). Otherwise the process is polled until it
// disappears, and 0 is returned because a non-child's status cannot be
// collected. Failures (a wait4 error, or a non-child observed as a zombie)
// yield 1.
func MonitorProcess(pid int, log *logrus.Entry) int {
	log.WithField("pid", pid).Info("Monitoring restored process")

	// Polling with kill(pid, 0) alone is not enough: an exited child remains a
	// zombie until it is reaped, and signal 0 still succeeds against a zombie,
	// so such a loop would never observe the exit. waitForProcess reaps the
	// child via wait4 and detects zombies in its non-child fallback.
	return waitForProcess(pid, log)
}

// ForwardProcessOutput forwards the stdout and stderr of a restored process
// to our own stdout/stderr via /proc/<pid>/fd/1 and /proc/<pid>/fd/2.
// This ensures logs from the restored process appear in kubectl logs.
// It blocks until the process exits and returns its exit code, as computed
// by waitForProcess.
//
// Forwarding is best-effort. Once the process has exited, the copy goroutines
// get drainTimeout to flush remaining output. If they are still blocked after
// that (e.g. the FD is a pipe whose write end is still held open elsewhere),
// a warning is logged, the exit code is returned, and those goroutines are
// left running in the background.
func ForwardProcessOutput(pid int, log *logrus.Entry) int {
	// drainTimeout bounds how long we wait for output to drain after exit.
	const drainTimeout = 2 * time.Second

	log.WithField("pid", pid).Info("Forwarding output from restored process")

	var wg sync.WaitGroup

	// NOTE(review): forwardFD opens the /proc path inside its goroutine, so a
	// process that exits (and is reaped below) before the open runs loses its
	// output. Opening both FDs before waiting would close that window.
	wg.Add(2)
	go forwardFD(fmt.Sprintf("/proc/%d/fd/1", pid), os.Stdout, "stdout", log, &wg)
	go forwardFD(fmt.Sprintf("/proc/%d/fd/2", pid), os.Stderr, "stderr", log, &wg)

	// Wait for process to exit (and reap it if it's our child).
	exitCode := waitForProcess(pid, log)

	// Give copy goroutines a short window to flush/finish.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	timer := time.NewTimer(drainTimeout)
	defer timer.Stop()
	select {
	case <-done:
	case <-timer.C:
		log.WithField("pid", pid).Warn("Timed out waiting for output forwarding goroutines to finish")
	}

	return exitCode
}

// forwardFD copies data from a file descriptor path to a writer.
// It handles the case where the FD may not be readable.
89
90
91
func forwardFD(fdPath string, dst io.Writer, name string, log *logrus.Entry, wg *sync.WaitGroup) {
	defer wg.Done()

92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
	// Try to open the FD path
	src, err := os.Open(fdPath)
	if err != nil {
		log.WithError(err).WithField("path", fdPath).Debug("Could not open process FD for forwarding")
		return
	}
	defer src.Close()

	// Check what kind of file this is
	stat, err := src.Stat()
	if err != nil {
		log.WithError(err).WithField("path", fdPath).Debug("Could not stat process FD")
		return
	}

	log.WithFields(logrus.Fields{
		"name": name,
		"mode": stat.Mode().String(),
		"path": fdPath,
	}).Debug("Forwarding process output")

113
114
115
	_, err = io.Copy(dst, src)
	if err != nil && !errors.Is(err, io.EOF) {
		log.WithError(err).WithField("name", name).Debug("Error reading from process FD")
116
117
118
119
120
	}
}

// waitForProcess waits for a process to exit and returns its exit code.
//
// The restored process is typically a direct child, so it is reaped with
// wait4(2). Reaping avoids leaving a zombie and yields the real status: the
// exit status for a normal exit, or 128+N when the process was terminated by
// signal N. If pid is not our child (wait4 returns ECHILD), this falls back to
// waitForProcessBySignal. A non-positive pid, any other wait4 error, or an
// unexpected wait status yields 1.
func waitForProcess(pid int, log *logrus.Entry) int {
	// wait4 treats pid <= 0 as a wildcard ("any child" / "any child in a
	// process group"), which could reap and report an unrelated child.
	if pid <= 0 {
		log.WithField("pid", pid).Error("Invalid PID for restored process")
		return 1
	}

	var status syscall.WaitStatus
	for {
		wpid, err := syscall.Wait4(pid, &status, 0, nil)
		if errors.Is(err, syscall.EINTR) {
			// Interrupted by a signal before the child changed state; retry.
			continue
		}
		if err != nil {
			if errors.Is(err, syscall.ECHILD) {
				log.WithField("pid", pid).Warn("Restored process is not a child; falling back to signal-based monitoring")
				return waitForProcessBySignal(pid, log)
			}
			log.WithError(err).WithField("pid", pid).Error("Wait4 failed for restored process")
			return 1
		}
		if wpid != pid {
			continue
		}

		switch {
		case status.Exited():
			exitCode := status.ExitStatus()
			log.WithFields(logrus.Fields{
				"pid":       pid,
				"exit_code": exitCode,
			}).Info("Restored process exited")
			return exitCode
		case status.Signaled():
			exitCode := 128 + int(status.Signal())
			log.WithFields(logrus.Fields{
				"pid":       pid,
				"signal":    status.Signal().String(),
				"exit_code": exitCode,
			}).Warn("Restored process terminated by signal")
			return exitCode
		default:
			log.WithField("pid", pid).Warn("Restored process exited with unexpected wait status")
			return 1
		}
	}
}
161

162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
func waitForProcessBySignal(pid int, log *logrus.Entry) int {
	for {
		proc, err := os.FindProcess(pid)
		if err != nil {
			log.WithError(err).WithField("pid", pid).Error("Failed to find restored process")
			return 1
		}
		if err := proc.Signal(syscall.Signal(0)); err != nil {
			log.WithField("pid", pid).Info("Restored process no longer exists")
			return 0
		}
		// Detect zombie state when wait4 is unavailable.
		if state, err := readProcState(pid); err == nil && state == "Z" {
			log.WithField("pid", pid).Warn("Restored process is zombie while not reaped by this process")
			return 1
		}
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
		time.Sleep(100 * time.Millisecond)
	}
}

// getExitCode attempts to get the exit code of a process without blocking.
//
// It can only succeed when pid is an already-exited (zombie) child of this
// process. In that case the child is reaped and its exit status is returned,
// or 128+N if it was terminated by signal N. In every other case it returns 0
// because the status cannot be determined: the process is still running, is
// not our child, was already reaped, or pid <= 0.
func getExitCode(pid int) int {
	// wait4 treats pid <= 0 as a wildcard; never reap an unrelated child.
	if pid <= 0 {
		return 0
	}

	var wstatus syscall.WaitStatus
	for {
		wpid, err := syscall.Wait4(pid, &wstatus, syscall.WNOHANG, nil)
		if errors.Is(err, syscall.EINTR) {
			continue
		}
		if err != nil || wpid != pid {
			// err (e.g. ECHILD: not our child or already reaped) or
			// wpid == 0 (WNOHANG: child has not exited yet).
			return 0
		}
		break
	}

	switch {
	case wstatus.Exited():
		return wstatus.ExitStatus()
	case wstatus.Signaled():
		return 128 + int(wstatus.Signal())
	default:
		return 0
	}
}

// readProcState reads /proc/<pid>/status and returns the token that follows
// the first "State:" label (e.g. "S", or "Z" for a zombie). It returns the
// read error if the file cannot be read, or an error if no "State:" line
// with a value is present.
func readProcState(pid int) (string, error) {
	raw, readErr := os.ReadFile(fmt.Sprintf("/proc/%d/status", pid))
	if readErr != nil {
		return "", readErr
	}

	lines := strings.Split(string(raw), "\n")
	for i := range lines {
		if !strings.HasPrefix(lines[i], "State:") {
			continue
		}
		// Only the first "State:" line is considered.
		parts := strings.Fields(lines[i])
		if len(parts) < 2 {
			break
		}
		return parts[1], nil
	}
	return "", fmt.Errorf("state field not found in /proc/%d/status", pid)
}

229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
// SetupSignalForwarding sets up signal forwarding to the restored process.
// Returns a cleanup function that should be called when done.
func SetupSignalForwarding(pid int, log *logrus.Entry) func() {
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)

	done := make(chan struct{})

	go func() {
		select {
		case sig := <-sigChan:
			log.WithFields(logrus.Fields{
				"signal": sig,
				"pid":    pid,
			}).Info("Forwarding signal to restored process")

			proc, err := os.FindProcess(pid)
			if err == nil {
				proc.Signal(sig)
			}
		case <-done:
			return
		}
	}()

	return func() {
		signal.Stop(sigChan)
		close(done)
	}
}

// WaitForPidFile waits for the CRIU PID file to be created and returns the PID.
//
// The file is polled every pollInterval until it contains a positive decimal
// PID (surrounding whitespace is ignored) or timeout elapses. The file is
// always read at least once, and once more at the deadline, so an already
// present file succeeds even with a non-positive timeout. On timeout, the
// returned error wraps the last read or parse failure.
func WaitForPidFile(pidFile string, timeout time.Duration, log *logrus.Entry) (int, error) {
	const pollInterval = 100 * time.Millisecond
	deadline := time.Now().Add(timeout)

	var lastErr error
	for {
		data, err := os.ReadFile(pidFile)
		if err == nil {
			pidStr := strings.TrimSpace(string(data))
			pid, parseErr := strconv.Atoi(pidStr)
			switch {
			case parseErr != nil:
				// The file may exist but not be fully written yet; keep polling.
				err = fmt.Errorf("parse PID file %s: %w", pidFile, parseErr)
			case pid <= 0:
				err = fmt.Errorf("invalid PID %d in %s", pid, pidFile)
			default:
				log.WithFields(logrus.Fields{
					"pid":  pid,
					"file": pidFile,
				}).Debug("Read PID from PID file")
				return pid, nil
			}
		}
		lastErr = err

		remaining := time.Until(deadline)
		if remaining <= 0 {
			break
		}
		// Never sleep past the deadline, so the final read happens on time.
		sleep := pollInterval
		if remaining < sleep {
			sleep = remaining
		}
		time.Sleep(sleep)
	}

	return 0, fmt.Errorf("timeout waiting for PID file %s after %v: %w", pidFile, timeout, lastErr)
}

279
280
281
282
283
284
// ExecColdStart execs the cold start command (ColdStartArgs), replacing the current process.
// If no args are provided, falls back to sleep infinity.
func ExecColdStart(cfg *RestoreRequest, log *logrus.Entry) error {
	if len(cfg.ColdStartArgs) == 0 {
		log.Warn("No cold start command provided, sleeping indefinitely")
		return ExecArgs([]string{"sleep", "infinity"}, log)
285
286
	}

287
288
	log.WithField("cmd", cfg.ColdStartArgs).Info("Executing cold start command")
	return ExecArgs(cfg.ColdStartArgs, log)
289
290
}

291
292
293
294
// ExecArgs replaces the current process with the given command and arguments.
// Uses syscall.Exec for proper PID 1 behavior in containers.
func ExecArgs(args []string, log *logrus.Entry) error {
	if len(args) == 0 {
295
296
297
298
		return fmt.Errorf("empty command")
	}

	// Find the executable path
299
	path, err := exec.LookPath(args[0])
300
	if err != nil {
301
		return fmt.Errorf("command not found: %s: %w", args[0], err)
302
303
	}

304
305
306
307
308
	log.WithFields(logrus.Fields{
		"path": path,
		"args": args,
	}).Debug("Replacing process via syscall.Exec")

309
310
311
	// Replace current process with the command
	return syscall.Exec(path, args, os.Environ())
}