config.go 3.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// config.go provides configuration loading for the checkpoint agent.
package main

import (
	"errors"
	"fmt"
	"os"

	"gopkg.in/yaml.v3"

	"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
)

// ConfigMapPath is the default path where the ConfigMap is mounted.
const ConfigMapPath = "/etc/chrek/config.yaml"

// CheckpointSignalSource names the mechanism that triggers checkpoint
// operations. Its valid values are SignalFromHTTP and SignalFromWatcher.
type CheckpointSignalSource string

// Recognized CheckpointSignalSource values. Each is a typed constant, so it can
// be compared directly against a CheckpointSignalSource.
const (
	// SignalFromHTTP triggers checkpoints via HTTP API requests.
	SignalFromHTTP = CheckpointSignalSource("http")

	// SignalFromWatcher triggers checkpoints automatically when pods become Ready.
	SignalFromWatcher = CheckpointSignalSource("watcher")
)

// FullConfig is the root configuration structure loaded from the ConfigMap.
// It mirrors the top-level YAML document: the "agent" key is decoded into
// AgentConfig and the "checkpoint" key into checkpoint.CheckpointSpec (whose
// schema is defined in the checkpoint package, not here).
type FullConfig struct {
	// Agent configures this agent daemon (YAML key "agent").
	Agent      AgentConfig               `yaml:"agent"`
	// Checkpoint holds the checkpoint settings (YAML key "checkpoint").
	Checkpoint checkpoint.CheckpointSpec `yaml:"checkpoint"`
}

// AgentConfig holds the runtime configuration for the checkpoint agent daemon.
//
// SignalSource and ListenAddr are read from the "agent" section of the YAML
// file. NodeName and RestrictedNamespace are tagged `yaml:"-"`, so the YAML
// decoder never sets them; they are populated from environment variables by
// loadEnvOverrides.
type AgentConfig struct {
	// SignalSource determines how checkpoints are triggered: "http" or "watcher".
	// Stored as a plain string; GetSignalSource returns it as a
	// CheckpointSignalSource, and Validate rejects any other value.
	SignalSource string `yaml:"signalSource"`

	// ListenAddr is the HTTP server address for health checks and API.
	// Validate requires it to be non-empty when SignalSource is "http".
	ListenAddr string `yaml:"listenAddr"`

	// NodeName is the Kubernetes node name (from NODE_NAME env, downward API).
	// Never read from YAML.
	NodeName string `yaml:"-"`

	// RestrictedNamespace restricts pod watching to this namespace (optional).
	// Set from the RESTRICTED_NAMESPACE env var; never read from YAML.
	RestrictedNamespace string `yaml:"-"`
}

// ConfigError reports that a single configuration field failed validation.
type ConfigError struct {
	Field   string // name of the offending field, as spelled in the YAML document
	Message string // human-readable description of what is wrong with it
}

// Error implements the error interface, rendering the error as
// "config error: <Field>: <Message>".
func (e *ConfigError) Error() string {
	const prefix = "config error"
	return fmt.Sprintf("%s: %s: %s", prefix, e.Field, e.Message)
}

// LoadConfig reads the YAML document at path, decodes it into a FullConfig,
// and then applies environment-variable overrides to the agent section.
//
// Read and parse failures are wrapped with %w, so callers can inspect the
// underlying cause with errors.Is / errors.As. No validation is performed;
// call Validate on the result for that.
func LoadConfig(path string) (*FullConfig, error) {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read config file %s: %w", path, readErr)
	}

	var parsed FullConfig
	if parseErr := yaml.Unmarshal(raw, &parsed); parseErr != nil {
		return nil, fmt.Errorf("failed to parse config file %s: %w", path, parseErr)
	}

	// NodeName / RestrictedNamespace are `yaml:"-"`, so they only ever come
	// from the environment.
	parsed.Agent.loadEnvOverrides()
	return &parsed, nil
}

// LoadConfigOrDefault loads configuration from path like LoadConfig, but treats
// a missing file as "no configuration". In that case it returns a zero-valued
// FullConfig with environment-variable overrides applied, and a nil error.
//
// Any other failure (unreadable file, malformed YAML, ...) is returned
// unchanged and the config is nil.
func LoadConfigOrDefault(path string) (*FullConfig, error) {
	cfg, err := LoadConfig(path)
	if err == nil {
		return cfg, nil
	}

	// LoadConfig wraps the os.ReadFile error with %w. os.IsNotExist does not
	// walk wrapped error chains, so it would never match here; errors.Is does.
	if errors.Is(err, os.ErrNotExist) {
		cfg = &FullConfig{}
		cfg.Agent.loadEnvOverrides()
		return cfg, nil
	}
	return nil, err
}

// loadEnvOverrides copies NODE_NAME into c.NodeName and RESTRICTED_NAMESPACE
// into c.RestrictedNamespace. A variable that is unset or empty leaves the
// corresponding field unchanged.
func (c *AgentConfig) loadEnvOverrides() {
	overrides := []struct {
		env string
		dst *string
	}{
		{env: "NODE_NAME", dst: &c.NodeName},
		{env: "RESTRICTED_NAMESPACE", dst: &c.RestrictedNamespace},
	}
	for _, o := range overrides {
		val := os.Getenv(o.env)
		if val == "" {
			continue
		}
		*o.dst = val
	}
}

// GetSignalSource converts the raw SignalSource string into the typed
// CheckpointSignalSource. It performs no validation: an unrecognized value is
// returned as-is (Validate is what rejects such values).
func (c *AgentConfig) GetSignalSource() CheckpointSignalSource {
	src := CheckpointSignalSource(c.SignalSource)
	return src
}

// Validate checks that the AgentConfig has valid values.
//
// SignalSource must be "http" or "watcher". In "http" mode ListenAddr must also
// be non-empty. On failure it returns a *ConfigError naming the offending field;
// otherwise it returns nil.
func (c *AgentConfig) Validate() error {
	switch CheckpointSignalSource(c.SignalSource) {
	case SignalFromHTTP:
		if c.ListenAddr == "" {
			return &ConfigError{
				Field:   "listenAddr",
				Message: "cannot be empty when signalSource is 'http'",
			}
		}
	case SignalFromWatcher:
		// No additional constraints are checked in watcher mode.
	default:
		return &ConfigError{
			Field:   "signalSource",
			Message: "must be 'http' or 'watcher'",
		}
	}
	return nil
}

// Validate checks the agent section first and then the checkpoint section. It
// returns the first error encountered, or nil when both sections are valid.
// The checkpoint section is not checked if the agent section fails.
func (c *FullConfig) Validate() error {
	if agentErr := c.Agent.Validate(); agentErr != nil {
		return agentErr
	}
	if ckptErr := c.Checkpoint.Validate(); ckptErr != nil {
		return ckptErr
	}
	return nil
}