shared.go 5.23 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/*
Copyright 2025 NVIDIA Corporation.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package disagg implements disaggregated prefill/decode serving plugins for Dynamo EPP.
//
// The disaggregated architecture splits inference into two phases:
//   - Prefill: processes the input prompt (compute-heavy, parallelizable)
//   - Decode: generates tokens autoregressively (memory-bound, sequential)
//
// This package provides three plugins:
//   - DisaggProfileHandler: orchestrates prefill→decode profile execution
//   - DynPrefillScorer: selects prefill workers via Dynamo FFI
//   - DynDecodeScorer: selects decode workers via Dynamo FFI
package disagg

import (
	"encoding/json"
	"fmt"
32
33
	"os"
	"strings"
34

35
36
37
38
39
	"github.com/go-logr/logr"
	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
	plugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
	fwkrh "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requesthandling"
	schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
40
41
42
43
44
45
46
47

	dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)

const (
	PrefillProfileName = "prefill"
	DecodeProfileName  = "decode"

48
	// PrefillEnabledStateKey tracks whether this request should use disaggregated routing.
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
	PrefillEnabledStateKey = plugins.StateKey("disagg-prefill-enabled")
)

// PrefillEnabledState is the CycleState payload recording whether prefill is
// enabled for the current scheduling cycle.
type PrefillEnabledState struct {
	Enabled bool
}

// Clone implements plugins.StateData by returning an independent copy of s.
func (s *PrefillEnabledState) Clone() plugins.StateData {
	dup := *s
	return &dup
}

// readPrefillEnabled reports whether prefill is enabled for this cycle, as
// recorded under PrefillEnabledStateKey in cycleState.
//
// It returns false when ReadCycleStateKey reports an error or yields a nil
// *PrefillEnabledState.
func readPrefillEnabled(cycleState *schedtypes.CycleState) bool {
	state, err := schedtypes.ReadCycleStateKey[*PrefillEnabledState](cycleState, PrefillEnabledStateKey)
	if err != nil || state == nil {
		return false
	}
	return state.Enabled
}

// buildRequestJSON builds an OpenAI-compatible JSON string from a GAIE LLMRequest.
72
func buildRequestJSON(req *schedtypes.InferenceRequest) (string, error) {
73
74
75
76
77
78
79
80
81
82
83
	requestBody, err := dynscorer.BuildOpenAIRequest(req)
	if err != nil {
		return "", fmt.Errorf("failed to build OpenAI request: %w", err)
	}
	data, err := json.Marshal(requestBody)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request JSON: %w", err)
	}
	return string(data), nil
}

84
85
86
// serializeEndpoints converts endpoints to a JSON string for the FFI filter.
func serializeEndpoints(endpoints []schedtypes.Endpoint) string {
	if len(endpoints) == 0 {
87
88
		return ""
	}
89
	pj, err := dynscorer.SerializeEndpointsToJSON(endpoints)
90
91
92
93
94
95
	if err != nil {
		return ""
	}
	return pj
}

96
97
98
99
100
// uniformScores returns a score map with the same score for every endpoint.
func uniformScores(endpoints []schedtypes.Endpoint, score float64) map[schedtypes.Endpoint]float64 {
	out := make(map[schedtypes.Endpoint]float64, len(endpoints))
	for _, ep := range endpoints {
		out[ep] = score
101
102
103
	}
	return out
}
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161

// setTokenizedPrompt stores pre-computed token IDs on the request and
// injects them as nvext.token_data into the request's PayloadMap, so they
// are forwarded to the worker in the request body.
//
// The GAIE framework re-serializes the PayloadMap after scheduling/PreRequest
// plugins run (PR #2854), so this mutation is included in the forwarded body.
//
// Behavior:
//   - A nil req or an empty token list is logged and ignored.
//   - Otherwise req.TokenizedPrompt is always set.
//   - nvext.token_data is injected only when req.Body is non-nil and its
//     Payload is a non-nil PayloadMap; in every other case an error is logged
//     and the worker-side sidecar is expected to re-tokenize.
//   - The same []uint32 slice backs both req.TokenizedPrompt.TokenIDs and
//     nvext.token_data.
//
// Each int64 token is narrowed to uint32 without a range check.
// NOTE(review): values outside [0, 2^32) would wrap silently — confirm the
// tokenizer never produces them.
func setTokenizedPrompt(req *schedtypes.InferenceRequest, tokens []int64, logger logr.Logger) {
	if req == nil {
		logger.V(logutil.DEFAULT).Info("[EPP-INJECT] No tokens to inject (nil request)")
		return
	}
	if len(tokens) == 0 {
		logger.V(logutil.DEFAULT).Info("[EPP-INJECT] No tokens to inject (empty token list)")
		return
	}

	tokenIDs := make([]uint32, len(tokens))
	for i, t := range tokens {
		tokenIDs[i] = uint32(t)
	}

	req.TokenizedPrompt = &schedtypes.TokenizedPrompt{
		TokenIDs: tokenIDs,
	}

	if req.Body == nil {
		logger.V(logutil.DEFAULT).Error(nil, "[EPP-INJECT] Failed to inject nvext.token_data: request has no Body — sidecar will re-tokenize",
			"tokenCount", len(tokenIDs),
			"requestId", req.RequestId)
		return
	}

	pm, ok := req.Body.Payload.(fwkrh.PayloadMap)
	// A typed-nil PayloadMap passes the type assertion, but writing to it
	// below would panic (assignment to a nil map), so treat it as unusable.
	if !ok || pm == nil {
		logger.V(logutil.DEFAULT).Error(nil, "[EPP-INJECT] Failed to inject nvext.token_data: Payload is not a PayloadMap — sidecar will re-tokenize",
			"tokenCount", len(tokenIDs),
			"requestId", req.RequestId)
		return
	}

	// Merge into an existing nvext object so its other fields survive; a
	// missing nvext, or one that is not a map[string]any, is replaced by a
	// fresh map.
	nvext, _ := pm["nvext"].(map[string]any)
	if nvext == nil {
		nvext = map[string]any{}
	}
	nvext["token_data"] = tokenIDs
	pm["nvext"] = nvext

	logger.V(logutil.DEFAULT).Info("[EPP-INJECT] Injected pre-computed tokens into request body nvext.token_data",
		"tokenCount", len(tokenIDs),
		"requestId", req.RequestId)
}

// getEnvBoolOrDefault interprets the environment variable key as a boolean.
//
// Recognized values are case-insensitive and ignore surrounding whitespace:
//   - "true", "1", "yes" and "on" yield true.
//   - "false", "0", "no" and "off" yield false.
//
// If the variable is unset, empty, whitespace-only, or holds any other value,
// def is returned.
func getEnvBoolOrDefault(key string, def bool) bool {
	// Trim so values such as "true " (common with hand-edited manifests) are
	// not silently treated as unrecognized.
	switch strings.ToLower(strings.TrimSpace(os.Getenv(key))) {
	case "true", "1", "yes", "on":
		return true
	case "false", "0", "no", "off":
		return false
	default:
		return def
	}
}