shared.go 3.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
/*
Copyright 2025 NVIDIA Corporation.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package disagg implements disaggregated prefill/decode serving plugins for Dynamo EPP.
//
// The disaggregated architecture splits inference into two phases:
//   - Prefill: processes the input prompt (compute-heavy, parallelizable)
//   - Decode: generates tokens autoregressively (memory-bound, sequential)
//
// This package provides three plugins:
//   - DisaggProfileHandler: orchestrates prefill→decode profile execution
//   - DynPrefillScorer: selects prefill workers via Dynamo FFI
//   - DynDecodeScorer: selects decode workers via Dynamo FFI
package disagg

import (
	"encoding/json"
	"fmt"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
	schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"

	dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)

const (
	PrefillProfileName = "prefill"
	DecodeProfileName  = "decode"

43
44
45
46
47
48
49
50
51
	// PrefillEnabledStateKey tracks whether this request should use disaggregated routing.
	// Initially set to true by DisaggProfileHandler.Pick() if a "prefill" scheduling
	// profile exists in the EPP config. Overwritten to false per-request in two cases:
	//   - DisaggProfileHandler.Pick(): prefill profile result is nil (no prefill pods
	//     passed the label-filter).
	//   - DynPrefillScorer.Score(): prefill FFI routing failed (prefill router not yet
	//     activated, e.g., worker registered in K8s but not yet in Dynamo discovery).
	// The decode scorer reads this to decide whether to use overlap_score_weight=0
	// (disaggregated) or normal KV cache overlap scoring (aggregated).
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
	PrefillEnabledStateKey = plugins.StateKey("disagg-prefill-enabled")
)

// PrefillEnabledState stores whether prefill is enabled for the current scheduling cycle.
// Written by DisaggProfileHandler, read by PrefillScorer and DecodeScorer.
type PrefillEnabledState struct {
	Enabled bool
}

// Clone implements plugins.StateData.
func (s *PrefillEnabledState) Clone() plugins.StateData {
	return &PrefillEnabledState{Enabled: s.Enabled}
}

// readPrefillEnabled reads the PrefillEnabledState from CycleState.
// Returns false if the state is not found or not set.
func readPrefillEnabled(cycleState *schedtypes.CycleState) bool {
	state, err := schedtypes.ReadCycleStateKey[*PrefillEnabledState](cycleState, PrefillEnabledStateKey)
	if err == nil && state != nil {
		return state.Enabled
	}
	return false
}

// buildRequestJSON builds an OpenAI-compatible JSON string from a GAIE LLMRequest.
func buildRequestJSON(req *schedtypes.LLMRequest) (string, error) {
	requestBody, err := dynscorer.BuildOpenAIRequest(req)
	if err != nil {
		return "", fmt.Errorf("failed to build OpenAI request: %w", err)
	}
	data, err := json.Marshal(requestBody)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request JSON: %w", err)
	}
	return string(data), nil
}

// serializePods converts pods to a JSON string for the FFI filter.
// Returns an empty string if serialization fails or pods is empty.
func serializePods(pods []schedtypes.Pod) string {
	if len(pods) == 0 {
		return ""
	}
	pj, err := dynscorer.SerializePodsToJSON(pods)
	if err != nil {
		return ""
	}
	return pj
}

// uniformScores returns a score map with the same score for every pod.
func uniformScores(pods []schedtypes.Pod, score float64) map[schedtypes.Pod]float64 {
	out := make(map[schedtypes.Pod]float64, len(pods))
	for _, p := range pods {
		out[p] = score
	}
	return out
}