profile_handler.go 5.66 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*
Copyright 2025 NVIDIA Corporation.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package disagg

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	log "sigs.k8s.io/controller-runtime/pkg/log"
26
27
28
	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
	plugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
	schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
29
30
31
32
33
34
35
)

// DisaggProfileHandlerType is the plugin type identifier of DisaggProfileHandler.
// It is used as the Type of the handler's TypedName and as its default Name.
const DisaggProfileHandlerType = "disagg-profile-handler"

// compile-time type assertion
36
var _ schedtypes.ProfileHandler = &DisaggProfileHandler{}
37
38
39
40
41
42
43
44
45
46
47
48

// DisaggProfileHandlerConfig holds the configuration for the DisaggProfileHandler.
// It currently declares no fields; DisaggProfileHandlerFactory still unmarshals the
// plugin's JSON parameters into it, so malformed JSON is rejected at construction time.
type DisaggProfileHandlerConfig struct{}

// DisaggProfileHandlerFactory defines the factory function for DisaggProfileHandler.
//
// name becomes the Name of the returned plugin's TypedName. rawParameters, when
// non-empty, must be valid JSON for DisaggProfileHandlerConfig; a decoding failure
// is returned wrapped. Nil and zero-length parameters both mean "no configuration".
// The plugin handle is not used.
//
// Whether disaggregation is enforced is read from the DYN_ENFORCE_DISAGG environment
// variable through getEnvBoolOrDefault, with false as the default.
func DisaggProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	// len() covers both a nil and an empty-but-non-nil RawMessage; handing the
	// latter to json.Unmarshal would fail with "unexpected end of JSON input".
	if len(rawParameters) > 0 {
		var cfg DisaggProfileHandlerConfig
		if err := json.Unmarshal(rawParameters, &cfg); err != nil {
			return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", DisaggProfileHandlerType, err)
		}
	}

	enforceDisagg := getEnvBoolOrDefault("DYN_ENFORCE_DISAGG", false)
	return NewDisaggProfileHandler(enforceDisagg).WithName(name), nil
}

// NewDisaggProfileHandler initializes a new DisaggProfileHandler.
54
func NewDisaggProfileHandler(enforceDisagg bool) *DisaggProfileHandler {
55
	return &DisaggProfileHandler{
56
57
		typedName:     plugins.TypedName{Type: DisaggProfileHandlerType, Name: DisaggProfileHandlerType},
		enforceDisagg: enforceDisagg,
58
59
60
61
62
	}
}

// DisaggProfileHandler is a ProfileHandler that orchestrates prefill/decode disaggregated serving.
type DisaggProfileHandler struct {
63
64
	typedName     plugins.TypedName
	enforceDisagg bool
65
66
67
68
69
70
71
72
73
74
75
76
77
78
}

// TypedName returns the type and name tuple of this plugin instance.
// The tuple is initialized by NewDisaggProfileHandler; its Name can be
// overridden afterwards through WithName.
func (h *DisaggProfileHandler) TypedName() plugins.TypedName {
	return h.typedName
}

// WithName sets the name of the profile handler, leaving its type unchanged.
// It mutates the receiver in place and returns it so the call can be chained.
func (h *DisaggProfileHandler) WithName(name string) *DisaggProfileHandler {
	h.typedName.Name = name
	return h
}

// Pick selects which profiles to run in the current iteration.
79
80
func (h *DisaggProfileHandler) Pick(ctx context.Context, cycleState *schedtypes.CycleState, _ *schedtypes.InferenceRequest,
	profiles map[string]schedtypes.SchedulerProfile, profileResults map[string]*schedtypes.ProfileRunResult) map[string]schedtypes.SchedulerProfile {
81
82
83
84
85
86
87
88
89
90

	logger := log.FromContext(ctx).V(logutil.VERBOSE)

	if len(profileResults) == 0 {
		_, prefillExists := profiles[PrefillProfileName]
		state := &PrefillEnabledState{Enabled: prefillExists}
		cycleState.Write(PrefillEnabledStateKey, state)
		logger.Info("DisaggProfileHandler: prefill enabled state determined", "prefillEnabled", prefillExists)

		if prefillExists {
91
			return map[string]schedtypes.SchedulerProfile{
92
93
94
95
				PrefillProfileName: profiles[PrefillProfileName],
			}
		}
		if decodeProfile, ok := profiles[DecodeProfileName]; ok {
96
			return map[string]schedtypes.SchedulerProfile{
97
98
99
100
101
102
103
104
105
				DecodeProfileName: decodeProfile,
			}
		}
		return profiles
	}

	if prefillResult, prefillDone := profileResults[PrefillProfileName]; prefillDone {
		if _, decodeDone := profileResults[DecodeProfileName]; !decodeDone {
			if prefillResult == nil {
106
107
				if h.enforceDisagg {
					logger.Info("DisaggProfileHandler: prefill profile failed and enforce_disagg=true, rejecting request")
108
					return map[string]schedtypes.SchedulerProfile{}
109
				}
110
111
112
113
114
				logger.Info("DisaggProfileHandler: prefill profile failed (no workers?), falling back to aggregated decode")
				cycleState.Write(PrefillEnabledStateKey, &PrefillEnabledState{Enabled: false})
			}

			if decodeProfile, ok := profiles[DecodeProfileName]; ok {
115
				return map[string]schedtypes.SchedulerProfile{
116
117
118
119
120
121
					DecodeProfileName: decodeProfile,
				}
			}
		}
	}

122
	return map[string]schedtypes.SchedulerProfile{}
123
124
125
}

// ProcessResults aggregates the profile run results and designates the primary profile.
126
func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.CycleState, req *schedtypes.InferenceRequest,
127
128
	profileResults map[string]*schedtypes.ProfileRunResult) (*schedtypes.SchedulingResult, error) {

129
130
131
132
	if h.enforceDisagg && (req.Headers == nil || req.Headers[PrefillWorkerIDHeader] == "") {
		if _, prefillRan := profileResults[PrefillProfileName]; prefillRan {
			return nil, errors.New(
				"disaggregated mode enforced (DYN_ENFORCE_DISAGG=true) but prefill workers " +
133
					"are not available; request rejected")
134
135
136
		}
	}

137
138
139
140
141
142
143
144
145
146
147
148
149
	if len(profileResults) == 0 {
		return nil, errors.New("disagg profile handler received no profile results")
	}

	primaryProfile := DecodeProfileName
	if _, ok := profileResults[DecodeProfileName]; !ok {
		for name := range profileResults {
			primaryProfile = name
			break
		}
	}

	if profileResults[primaryProfile] == nil {
150
151
152
		if h.enforceDisagg {
			return nil, errors.New(
				"disaggregated mode enforced (DYN_ENFORCE_DISAGG=true) but prefill workers " +
153
					"are not available; request rejected")
154
		}
155
156
157
158
159
160
161
162
		return nil, fmt.Errorf("primary profile '%s' failed to produce a result", primaryProfile)
	}

	return &schedtypes.SchedulingResult{
		ProfileResults:     profileResults,
		PrimaryProfileName: primaryProfile,
	}, nil
}