discovery.rs 5.39 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Ryan Olson's avatar
Ryan Olson committed
16
use std::sync::Arc;
17

Ryan Olson's avatar
Ryan Olson committed
18
use serde::{Deserialize, Serialize};
19
use tokio::sync::mpsc::Receiver;
Ryan Olson's avatar
Ryan Olson committed
20

Neelay Shah's avatar
Neelay Shah committed
21
use triton_distributed_runtime::{
Ryan Olson's avatar
Ryan Olson committed
22
    protocols::{self, annotated::Annotated},
23
    raise,
Ryan Olson's avatar
Ryan Olson committed
24
25
    transports::etcd::{KeyValue, WatchEvent},
    DistributedRuntime, Result,
26
};
Ryan Olson's avatar
Ryan Olson committed
27
28

use super::ModelManager;
29
use crate::model_type::ModelType;
Ryan Olson's avatar
Ryan Olson committed
30
31
use crate::protocols::openai::chat_completions::{
    ChatCompletionRequest, ChatCompletionResponseDelta,
32
};
33
34
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
use tracing;
Ryan Olson's avatar
Ryan Olson committed
35
36
37
38
39
40
41
42
43
44
45
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
/// from the etcd cluster.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
    /// Public name of the model
    /// This will be used to identify the model in the HTTP service and the value used in an
    /// an [OAI ChatRequest][crate::protocols::openai::chat_completions::ChatCompletionRequest].
    pub name: String,

    /// Component of the endpoint.
    pub endpoint: protocols::Endpoint,
46
47
48

    /// Specifies whether the model is a chat or completion model.s
    pub model_type: ModelType,
Ryan Olson's avatar
Ryan Olson committed
49
50
51
52
}

pub struct ModelWatchState {
    pub prefix: String,
53
    pub model_type: ModelType,
Ryan Olson's avatar
Ryan Olson committed
54
55
56
57
58
    pub manager: ModelManager,
    pub drt: DistributedRuntime,
}

pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<WatchEvent>) {
59
    tracing::debug!("model watcher started");
Ryan Olson's avatar
Ryan Olson committed
60
61
62
63
64
65

    let mut events_rx = events_rx;

    while let Some(event) = events_rx.recv().await {
        match event {
            WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await {
66
67
                Ok((model_name, model_type)) => {
                    tracing::info!("added {} model: {}", model_type, model_name);
Ryan Olson's avatar
Ryan Olson committed
68
69
                }
                Err(e) => {
70
                    tracing::error!("error adding model: {}", e);
Ryan Olson's avatar
Ryan Olson committed
71
72
73
                }
            },
            WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
74
75
                Ok((model_name, model_type)) => {
                    tracing::info!("removed {} model: {}", model_type, model_name);
Ryan Olson's avatar
Ryan Olson committed
76
77
                }
                Err(e) => {
78
                    tracing::error!("error removing model: {}", e);
Ryan Olson's avatar
Ryan Olson committed
79
80
81
82
83
                }
            },
        }
    }

84
    tracing::debug!("model watcher stopped");
Ryan Olson's avatar
Ryan Olson committed
85
86
}

87
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
88
    tracing::debug!("removing model");
Ryan Olson's avatar
Ryan Olson committed
89
90

    let key = kv.key_str()?;
91
    tracing::debug!("key: {}", key);
Ryan Olson's avatar
Ryan Olson committed
92
93

    let model_name = key.trim_start_matches(&state.prefix);
94
95
96
97
98
99
100

    match state.model_type {
        ModelType::Chat => state.manager.remove_chat_completions_model(model_name)?,
        ModelType::Completion => state.manager.remove_completions_model(model_name)?,
    };

    Ok((model_name, state.model_type))
Ryan Olson's avatar
Ryan Olson committed
101
102
103
104
105
106
}

// Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models.
//
// If this method errors, for the near term, we will delete the offending key.
107
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
108
    tracing::debug!("adding model");
Ryan Olson's avatar
Ryan Olson committed
109
110

    let key = kv.key_str()?;
111
    tracing::debug!("key: {}", key);
Ryan Olson's avatar
Ryan Olson committed
112

113
    let model_name = key.trim_start_matches(&state.prefix);
Ryan Olson's avatar
Ryan Olson committed
114
115
116
117
118
119
120
121
122
    let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?;

    if model_entry.name != model_name {
        raise!(
            "model name mismatch: {} != {}",
            model_entry.name,
            model_name
        );
    }
123
124
125
126
127
128
129
    if model_entry.model_type != state.model_type {
        raise!(
            "model type mismatch: {} != {}",
            model_entry.model_type,
            state.model_type
        );
    }
Ryan Olson's avatar
Ryan Olson committed
130

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
    match state.model_type {
        ModelType::Chat => {
            let client = state
                .drt
                .namespace(model_entry.endpoint.namespace)?
                .component(model_entry.endpoint.component)?
                .endpoint(model_entry.endpoint.name)
                .client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
                .await?;
            state
                .manager
                .add_chat_completions_model(model_name, Arc::new(client))?;
        }
        ModelType::Completion => {
            let client = state
                .drt
                .namespace(model_entry.endpoint.namespace)?
                .component(model_entry.endpoint.component)?
                .endpoint(model_entry.endpoint.name)
                .client::<CompletionRequest, Annotated<CompletionResponse>>()
                .await?;
            state
                .manager
                .add_completions_model(model_name, Arc::new(client))?;
        }
    }
Ryan Olson's avatar
Ryan Olson committed
157

158
    Ok((model_name, state.model_type))
Ryan Olson's avatar
Ryan Olson committed
159
}