// Package main provides the CRIU node agent. The agent runs in exactly one of
// two mutually exclusive modes, selected by the configured signal source:
//   - HTTP API mode: exposes REST endpoints for checkpoint/restore operations.
//   - Watcher mode: automatically checkpoints pods labelled
//     nvidia.com/checkpoint-source=true.
package main

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
	httpApiServer "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/http_api_server"
	"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/watcher"
)

func main() {
22
23
24
25
	// Load configuration from ConfigMap (or use defaults if not found)
	cfg, err := LoadConfigOrDefault(ConfigMapPath)
	if err != nil {
		log.Fatalf("Failed to load configuration: %v", err)
26
27
	}

28
29
30
	// Validate configuration
	if err := cfg.Agent.Validate(); err != nil {
		log.Fatalf("Invalid configuration: %v", err)
31
32
33
	}

	// Create discovery client
34
	discoveryClient, err := checkpoint.NewDiscoveryClient()
35
36
37
38
39
40
	if err != nil {
		log.Fatalf("Failed to create discovery client: %v", err)
	}
	defer discoveryClient.Close()

	// Create checkpointer
41
	checkpointer := checkpoint.NewCheckpointer(discoveryClient)
42
43
44
45
46
47
48
49
50

	// Context for graceful shutdown
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Handle graceful shutdown
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

51
52
53
	log.Printf("CRIU Node Agent starting (node: %s)", cfg.Agent.NodeName)
	log.Printf("Checkpoint directory: %s", cfg.Checkpoint.BasePath)
	log.Printf("Signal source: %s", cfg.Agent.SignalSource)
54

55
	switch cfg.Agent.GetSignalSource() {
56
	case SignalFromHTTP:
57
58
59
60
		serverCfg := httpApiServer.ServerConfig{
			ListenAddr:     cfg.Agent.ListenAddr,
			NodeName:       cfg.Agent.NodeName,
			CheckpointSpec: &cfg.Checkpoint,
61
		}
62
		srv := httpApiServer.NewServer(serverCfg, checkpointer)
63
64
65
66
67
68

		// Handle graceful shutdown
		go func() {
			<-sigChan
			shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
			defer shutdownCancel()
69
			if err := srv.Shutdown(shutdownCtx); err != nil {
70
71
72
73
				log.Printf("HTTP server shutdown error: %v", err)
			}
		}()

74
		if err := srv.Start(); err != http.ErrServerClosed {
75
76
77
78
			log.Fatalf("HTTP server error: %v", err)
		}

	case SignalFromWatcher:
79
80
81
82
83
		watcherConfig := watcher.WatcherConfig{
			NodeName:            cfg.Agent.NodeName,
			ListenAddr:          cfg.Agent.ListenAddr,
			RestrictedNamespace: cfg.Agent.RestrictedNamespace,
			CheckpointSpec:      &cfg.Checkpoint,
84
85
86
87
88
89
90
91
92
93
94
95
96
97
		}

		podWatcher, err := watcher.NewWatcher(watcherConfig, discoveryClient, checkpointer)
		if err != nil {
			log.Fatalf("Failed to create pod watcher: %v", err)
		}

		// Handle graceful shutdown
		go func() {
			<-sigChan
			log.Println("Shutting down pod watcher...")
			cancel()
		}()

98
99
		log.Printf("Pod watcher started (watching for label: %s=true)", checkpoint.KubeLabelCheckpointSource)
		log.Printf("Health check endpoint: http://0.0.0.0%s/health", cfg.Agent.ListenAddr)
100
101
102
103
		if err := podWatcher.Start(ctx); err != nil {
			log.Printf("Pod watcher error: %v", err)
		}

104
105
	default:
		log.Fatalf("Unknown signal source: %s", cfg.Agent.SignalSource)
106
107
	}

108
	log.Println("Agent stopped")
109
}