// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;

6
use serde::{Deserialize, Serialize, de::DeserializeOwned};
7

8
9
use crate::protocols::tensor;

10
11
12
13
14
15
16
17
18
/// Bootstrap address used for disaggregated serving.
///
/// Both parts are optional: a part missing from the input deserializes as
/// `None`, and a `None` part is omitted from the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct DisaggregatedEndpoint {
    /// Bootstrap host, if one is advertised.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_host: Option<String>,

    /// Bootstrap port, if one is advertised.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<u16>,
}

19
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
20
21
22
23
24
25
26
pub struct ModelRuntimeConfig {
    pub total_kv_blocks: Option<u64>,

    pub max_num_seqs: Option<u64>,

    pub max_num_batched_tokens: Option<u64>,

27
28
29
30
    pub tool_call_parser: Option<String>,

    pub reasoning_parser: Option<String>,

Yan Ru Pei's avatar
Yan Ru Pei committed
31
32
33
34
    /// Total number of data parallel ranks for this worker (1 if DP not enabled)
    #[serde(default = "default_data_parallel_size")]
    pub data_parallel_size: u32,

35
36
    /// Enable worker-local KV indexer for tracking this worker's own KV cache state (default: true)
    #[serde(default = "default_local_indexer")]
37
38
    pub enable_local_indexer: bool,

39
40
41
    /// Mapping of engine-specific runtime configs
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub runtime_data: HashMap<String, serde_json::Value>,
42
43
44
45
46
47
48
49
50
51

    // Provide tensor model config in the case where the model type is Tensor.
    // Currently use JSON object for convinence, the programmatic way is to
    // define the model config struct as part of the tensor protocol and
    // import it here.
    // [gluo TODO] switch to ModelConfig if desired and workout a way to
    // prepare it in a convinent way, the protobuf library used by tonic
    // doesn't provide JSON parsing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tensor_model_config: Option<tensor::TensorModelConfig>,
52
53
54
55

    /// Bootstrap endpoint for disaggregated serving (prefill workers publish this)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub disaggregated_endpoint: Option<DisaggregatedEndpoint>,
56
57
}

Yan Ru Pei's avatar
Yan Ru Pei committed
58
59
60
61
/// Serde default for `data_parallel_size`: a single rank, i.e. data
/// parallelism is not enabled.
const fn default_data_parallel_size() -> u32 {
    1_u32
}

62
63
64
65
/// Serde default for `enable_local_indexer`: the worker-local KV indexer
/// is switched on unless the config says otherwise.
const fn default_local_indexer() -> bool {
    true
}

Yan Ru Pei's avatar
Yan Ru Pei committed
66
67
68
69
70
71
72
73
74
impl Default for ModelRuntimeConfig {
    fn default() -> Self {
        Self {
            total_kv_blocks: None,
            max_num_seqs: None,
            max_num_batched_tokens: None,
            tool_call_parser: None,
            reasoning_parser: None,
            data_parallel_size: default_data_parallel_size(),
75
            enable_local_indexer: true,
Yan Ru Pei's avatar
Yan Ru Pei committed
76
77
            runtime_data: HashMap::new(),
            tensor_model_config: None,
78
            disaggregated_endpoint: None,
Yan Ru Pei's avatar
Yan Ru Pei committed
79
80
81
82
        }
    }
}

83
84
85
86
87
88
89
90
91
92
93
94
95
96
/// Exposes selected runtime fields through `dynamo_kv_router::WorkerConfigLike`.
///
/// Each accessor returns the stored field unchanged.
impl dynamo_kv_router::WorkerConfigLike for ModelRuntimeConfig {
    fn total_kv_blocks(&self) -> Option<u64> {
        let Self { total_kv_blocks, .. } = self;
        *total_kv_blocks
    }

    fn max_num_batched_tokens(&self) -> Option<u64> {
        let Self {
            max_num_batched_tokens,
            ..
        } = self;
        *max_num_batched_tokens
    }

    fn data_parallel_size(&self) -> u32 {
        let Self {
            data_parallel_size, ..
        } = self;
        *data_parallel_size
    }
}

97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
impl ModelRuntimeConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn set_engine_specific<T: Serialize>(&mut self, key: &str, value: T) -> anyhow::Result<()> {
        self.runtime_data
            .insert(key.to_string(), serde_json::to_value(value)?);
        Ok(())
    }

    pub fn get_engine_specific<T: DeserializeOwned>(&self, key: &str) -> anyhow::Result<Option<T>> {
        if let Some(value) = self.runtime_data.get(key) {
            Ok(Some(serde_json::from_value(value.clone())?))
        } else {
            Ok(None)
        }
    }
}