Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
537759f1
Unverified
Commit
537759f1
authored
Aug 15, 2025
by
Abrar Shivani
Committed by
GitHub
Aug 15, 2025
Browse files
feat: Dynamic Endpoint Exposure Based on Model Type (#1447)
parent
a4e06895
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
411 additions
and
54 deletions
+411
-54
lib/llm/src/discovery.rs
lib/llm/src/discovery.rs
+1
-1
lib/llm/src/discovery/watcher.rs
lib/llm/src/discovery/watcher.rs
+103
-3
lib/llm/src/endpoint_type.rs
lib/llm/src/endpoint_type.rs
+49
-0
lib/llm/src/entrypoint/input/http.rs
lib/llm/src/entrypoint/input/http.rs
+82
-6
lib/llm/src/http/service/service_v2.rs
lib/llm/src/http/service/service_v2.rs
+136
-39
lib/llm/src/lib.rs
lib/llm/src/lib.rs
+1
-0
lib/llm/src/model_type.rs
lib/llm/src/model_type.rs
+9
-0
lib/llm/tests/http-service.rs
lib/llm/tests/http-service.rs
+30
-5
No files found.
lib/llm/src/discovery.rs
View file @
537759f1
...
...
@@ -8,7 +8,7 @@ mod model_entry;
pub
use
model_entry
::
ModelEntry
;
mod
watcher
;
pub
use
watcher
::
ModelWatcher
;
pub
use
watcher
::
{
ModelUpdate
,
ModelWatcher
}
;
/// The root etcd path for ModelEntry
pub
const
MODEL_ROOT_PATH
:
&
str
=
"models"
;
lib/llm/src/discovery/watcher.rs
View file @
537759f1
...
...
@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::
Arc
;
use
tokio
::
sync
::
mpsc
::
Sender
;
use
anyhow
::
Context
as
_
;
use
tokio
::
sync
::{
mpsc
::
Receiver
,
Notify
};
...
...
@@ -36,14 +37,24 @@ use crate::{
use
super
::{
ModelEntry
,
ModelManager
,
MODEL_ROOT_PATH
};
#[derive(Debug,
Clone,
Copy,
PartialEq)]
pub
enum
ModelUpdate
{
Added
(
ModelType
),
Removed
(
ModelType
),
}
pub
struct
ModelWatcher
{
manager
:
Arc
<
ModelManager
>
,
drt
:
DistributedRuntime
,
router_mode
:
RouterMode
,
notify_on_model
:
Notify
,
model_update_tx
:
Option
<
Sender
<
ModelUpdate
>>
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
}
const
ALL_MODEL_TYPES
:
&
[
ModelType
]
=
&
[
ModelType
::
Chat
,
ModelType
::
Completion
,
ModelType
::
Embedding
];
impl
ModelWatcher
{
pub
fn
new
(
runtime
:
DistributedRuntime
,
...
...
@@ -56,10 +67,15 @@ impl ModelWatcher {
drt
:
runtime
,
router_mode
,
notify_on_model
:
Notify
::
new
(),
model_update_tx
:
None
,
kv_router_config
,
}
}
pub
fn
set_notify_on_model_update
(
&
mut
self
,
tx
:
Sender
<
ModelUpdate
>
)
{
self
.model_update_tx
=
Some
(
tx
);
}
/// Wait until we have at least one chat completions model and return it's name.
pub
async
fn
wait_for_chat_model
(
&
self
)
->
String
{
// Loop in case it gets added and immediately deleted
...
...
@@ -100,6 +116,12 @@ impl ModelWatcher {
};
self
.manager
.save_model_entry
(
key
,
model_entry
.clone
());
if
let
Some
(
tx
)
=
&
self
.model_update_tx
{
tx
.send
(
ModelUpdate
::
Added
(
model_entry
.model_type
))
.await
.ok
();
}
if
self
.manager
.has_model_any
(
&
model_entry
.name
)
{
tracing
::
trace!
(
name
=
model_entry
.name
,
"New endpoint for existing model"
);
self
.notify_on_model
.notify_waiters
();
...
...
@@ -151,13 +173,91 @@ impl ModelWatcher {
.await
.with_context
(||
model_name
.clone
())
?
;
if
!
active_instances
.is_empty
()
{
let
mut
update_tx
=
true
;
let
mut
model_type
:
ModelType
=
model_entry
.model_type
;
if
model_entry
.model_type
==
ModelType
::
Chat
&&
self
.manager
.list_chat_completions_models
()
.is_empty
()
{
self
.manager
.remove_chat_completions_model
(
&
model_name
)
.ok
();
model_type
=
ModelType
::
Chat
;
}
else
if
model_entry
.model_type
==
ModelType
::
Completion
&&
self
.manager
.list_completions_models
()
.is_empty
()
{
self
.manager
.remove_completions_model
(
&
model_name
)
.ok
();
model_type
=
ModelType
::
Completion
;
}
else
if
model_entry
.model_type
==
ModelType
::
Embedding
&&
self
.manager
.list_embeddings_models
()
.is_empty
()
{
self
.manager
.remove_embeddings_model
(
&
model_name
)
.ok
();
model_type
=
ModelType
::
Embedding
;
}
else
if
model_entry
.model_type
==
ModelType
::
Backend
{
if
self
.manager
.list_chat_completions_models
()
.is_empty
()
{
self
.manager
.remove_chat_completions_model
(
&
model_name
)
.ok
();
model_type
=
ModelType
::
Chat
;
}
if
self
.manager
.list_completions_models
()
.is_empty
()
{
self
.manager
.remove_completions_model
(
&
model_name
)
.ok
();
if
model_type
==
ModelType
::
Chat
{
model_type
=
ModelType
::
Backend
;
}
else
{
model_type
=
ModelType
::
Completion
;
}
}
}
else
{
tracing
::
debug!
(
"Model {} is still active in other instances, not removing"
,
model_name
);
update_tx
=
false
;
}
if
update_tx
{
if
let
Some
(
tx
)
=
&
self
.model_update_tx
{
tx
.send
(
ModelUpdate
::
Removed
(
model_type
))
.await
.ok
();
}
}
return
Ok
(
None
);
}
// Ignore the errors because model could be either type
let
_
=
self
.manager
.remove_chat_completions_model
(
&
model_name
);
let
_
=
self
.manager
.remove_completions_model
(
&
model_name
);
let
_
=
self
.manager
.remove_embeddings_model
(
&
model_name
);
let
chat_model_remove_err
=
self
.manager
.remove_chat_completions_model
(
&
model_name
);
let
completions_model_remove_err
=
self
.manager
.remove_completions_model
(
&
model_name
);
let
embeddings_model_remove_err
=
self
.manager
.remove_embeddings_model
(
&
model_name
);
let
mut
chat_model_removed
=
false
;
let
mut
completions_model_removed
=
false
;
let
mut
embeddings_model_removed
=
false
;
if
chat_model_remove_err
.is_ok
()
&&
self
.manager
.list_chat_completions_models
()
.is_empty
()
{
chat_model_removed
=
true
;
}
if
completions_model_remove_err
.is_ok
()
&&
self
.manager
.list_completions_models
()
.is_empty
()
{
completions_model_removed
=
true
;
}
if
embeddings_model_remove_err
.is_ok
()
&&
self
.manager
.list_embeddings_models
()
.is_empty
()
{
embeddings_model_removed
=
true
;
}
if
!
chat_model_removed
&&
!
completions_model_removed
&&
!
embeddings_model_removed
{
tracing
::
debug!
(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}"
,
model_name
,
chat_model_removed
,
completions_model_removed
,
embeddings_model_removed
);
}
else
{
for
model_type
in
ALL_MODEL_TYPES
{
if
(
chat_model_removed
&&
*
model_type
==
ModelType
::
Chat
)
||
(
completions_model_removed
&&
*
model_type
==
ModelType
::
Completion
)
||
(
embeddings_model_removed
&&
*
model_type
==
ModelType
::
Embedding
)
{
if
let
Some
(
tx
)
=
&
self
.model_update_tx
{
tx
.send
(
ModelUpdate
::
Removed
(
*
model_type
))
.await
.ok
();
}
}
}
}
Ok
(
Some
(
model_name
))
}
...
...
lib/llm/src/endpoint_type.rs
0 → 100644
View file @
537759f1
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
serde
::{
Deserialize
,
Serialize
};
use
strum
::
Display
;
#[derive(Copy,
Debug,
Clone,
Display,
Serialize,
Deserialize,
Eq,
PartialEq,
Hash)]
pub
enum
EndpointType
{
// Chat Completions API
Chat
,
/// Older completions API
Completion
,
/// Embeddings API
Embedding
,
/// Responses API
Responses
,
}
impl
EndpointType
{
pub
fn
as_str
(
&
self
)
->
&
str
{
match
self
{
Self
::
Chat
=>
"chat"
,
Self
::
Completion
=>
"completion"
,
Self
::
Embedding
=>
"embedding"
,
Self
::
Responses
=>
"responses"
,
}
}
pub
fn
all
()
->
Vec
<
Self
>
{
vec!
[
Self
::
Chat
,
Self
::
Completion
,
Self
::
Embedding
,
Self
::
Responses
,
]
}
}
lib/llm/src/entrypoint/input/http.rs
View file @
537759f1
...
...
@@ -4,11 +4,13 @@
use
std
::
sync
::
Arc
;
use
crate
::{
discovery
::{
ModelManager
,
ModelWatcher
,
MODEL_ROOT_PATH
},
discovery
::{
ModelManager
,
ModelUpdate
,
ModelWatcher
,
MODEL_ROOT_PATH
},
endpoint_type
::
EndpointType
,
engines
::
StreamingEngineAdapter
,
entrypoint
::{
self
,
input
::
common
,
EngineConfig
},
http
::
service
::
service_v2
,
http
::
service
::
service_v2
::{
self
,
HttpService
}
,
kv_router
::
KvRouterConfig
,
model_type
::
ModelType
,
types
::
openai
::{
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
completions
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
},
...
...
@@ -22,9 +24,6 @@ use dynamo_runtime::{DistributedRuntime, Runtime};
pub
async
fn
run
(
runtime
:
Runtime
,
engine_config
:
EngineConfig
)
->
anyhow
::
Result
<
()
>
{
let
mut
http_service_builder
=
service_v2
::
HttpService
::
builder
()
.port
(
engine_config
.local_model
()
.http_port
())
.enable_chat_endpoints
(
true
)
.enable_cmpl_endpoints
(
true
)
.enable_embeddings_endpoints
(
true
)
.with_request_template
(
engine_config
.local_model
()
.request_template
());
let
http_service
=
match
engine_config
{
...
...
@@ -45,6 +44,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
MODEL_ROOT_PATH
,
router_config
.router_mode
,
Some
(
router_config
.kv_router_config
),
Arc
::
new
(
http_service
.clone
()),
)
.await
?
;
}
...
...
@@ -98,6 +98,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
.await
?
;
manager
.add_completions_model
(
local_model
.display_name
(),
completions_engine
)
?
;
for
endpoint_type
in
EndpointType
::
all
()
{
http_service
.enable_model_endpoint
(
endpoint_type
,
true
)
.await
;
}
http_service
}
EngineConfig
::
StaticFull
{
engine
,
model
,
..
}
=>
{
...
...
@@ -106,6 +112,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let
manager
=
http_service
.model_manager
();
manager
.add_completions_model
(
model
.service_name
(),
engine
.clone
())
?
;
manager
.add_chat_completions_model
(
model
.service_name
(),
engine
)
?
;
// Enable all endpoints
for
endpoint_type
in
EndpointType
::
all
()
{
http_service
.enable_model_endpoint
(
endpoint_type
,
true
)
.await
;
}
http_service
}
EngineConfig
::
StaticCore
{
...
...
@@ -129,6 +142,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
>
(
model
.card
(),
inner_engine
)
.await
?
;
manager
.add_completions_model
(
model
.service_name
(),
cmpl_pipeline
)
?
;
// Enable all endpoints
for
endpoint_type
in
EndpointType
::
all
()
{
http_service
.enable_model_endpoint
(
endpoint_type
,
true
)
.await
;
}
http_service
}
};
...
...
@@ -154,13 +173,70 @@ async fn run_watcher(
network_prefix
:
&
str
,
router_mode
:
RouterMode
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
http_service
:
Arc
<
HttpService
>
,
)
->
anyhow
::
Result
<
()
>
{
let
watch_obj
=
ModelWatcher
::
new
(
runtime
,
model_manager
,
router_mode
,
kv_router_config
);
let
mut
watch_obj
=
ModelWatcher
::
new
(
runtime
,
model_manager
,
router_mode
,
kv_router_config
);
tracing
::
info!
(
"Watching for remote model at {network_prefix}"
);
let
models_watcher
=
etcd_client
.kv_get_and_watch_prefix
(
network_prefix
)
.await
?
;
let
(
_
prefix
,
_
watcher
,
receiver
)
=
models_watcher
.dissolve
();
// Create a channel to receive model type updates
let
(
tx
,
mut
rx
)
=
tokio
::
sync
::
mpsc
::
channel
(
32
);
watch_obj
.set_notify_on_model_update
(
tx
);
// Spawn a task to watch for model type changes and update HTTP service endpoints
let
_
endpoint_enabler_task
=
tokio
::
spawn
(
async
move
{
while
let
Some
(
model_type
)
=
rx
.recv
()
.await
{
tracing
::
debug!
(
"Received model type update: {:?}"
,
model_type
);
update_http_endpoints
(
http_service
.clone
(),
model_type
)
.await
;
}
});
// Pass the sender to the watcher
let
_
watcher_task
=
tokio
::
spawn
(
async
move
{
watch_obj
.watch
(
receiver
)
.await
;
});
Ok
(())
}
/// Updates HTTP service endpoints based on available model types
async
fn
update_http_endpoints
(
service
:
Arc
<
HttpService
>
,
model_type
:
ModelUpdate
)
{
tracing
::
debug!
(
"Updating HTTP service endpoints for model type: {:?}"
,
model_type
);
match
model_type
{
ModelUpdate
::
Added
(
model_type
)
=>
match
model_type
{
ModelType
::
Backend
=>
{
service
.enable_model_endpoint
(
EndpointType
::
Chat
,
true
)
.await
;
service
.enable_model_endpoint
(
EndpointType
::
Completion
,
true
)
.await
;
}
_
=>
{
service
.enable_model_endpoint
(
model_type
.as_endpoint_type
(),
true
)
.await
;
}
},
ModelUpdate
::
Removed
(
model_type
)
=>
match
model_type
{
ModelType
::
Backend
=>
{
service
.enable_model_endpoint
(
EndpointType
::
Chat
,
false
)
.await
;
service
.enable_model_endpoint
(
EndpointType
::
Completion
,
false
)
.await
;
}
_
=>
{
service
.enable_model_endpoint
(
model_type
.as_endpoint_type
(),
false
)
.await
;
}
},
}
}
lib/llm/src/http/service/service_v2.rs
View file @
537759f1
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
collections
::
HashMap
;
use
std
::
env
::
var
;
use
std
::
sync
::
atomic
::
AtomicBool
;
use
std
::
sync
::
atomic
::
Ordering
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
...
...
@@ -9,6 +12,7 @@ use super::metrics;
use
super
::
Metrics
;
use
super
::
RouteDoc
;
use
crate
::
discovery
::
ModelManager
;
use
crate
::
endpoint_type
::
EndpointType
;
use
crate
::
request_template
::
RequestTemplate
;
use
anyhow
::
Result
;
use
derive_builder
::
Builder
;
...
...
@@ -19,10 +23,48 @@ use tokio_util::sync::CancellationToken;
use
tower_http
::
trace
::
TraceLayer
;
/// HTTP service shared state
#[derive(Default)]
pub
struct
State
{
metrics
:
Arc
<
Metrics
>
,
manager
:
Arc
<
ModelManager
>
,
etcd_client
:
Option
<
etcd
::
Client
>
,
flags
:
StateFlags
,
}
#[derive(Default,
Debug)]
struct
StateFlags
{
chat_endpoints_enabled
:
AtomicBool
,
cmpl_endpoints_enabled
:
AtomicBool
,
embeddings_endpoints_enabled
:
AtomicBool
,
responses_endpoints_enabled
:
AtomicBool
,
}
impl
StateFlags
{
pub
fn
get
(
&
self
,
endpoint_type
:
&
EndpointType
)
->
bool
{
match
endpoint_type
{
EndpointType
::
Chat
=>
self
.chat_endpoints_enabled
.load
(
Ordering
::
Relaxed
),
EndpointType
::
Completion
=>
self
.cmpl_endpoints_enabled
.load
(
Ordering
::
Relaxed
),
EndpointType
::
Embedding
=>
self
.embeddings_endpoints_enabled
.load
(
Ordering
::
Relaxed
),
EndpointType
::
Responses
=>
self
.responses_endpoints_enabled
.load
(
Ordering
::
Relaxed
),
}
}
pub
fn
set
(
&
self
,
endpoint_type
:
&
EndpointType
,
enabled
:
bool
)
{
match
endpoint_type
{
EndpointType
::
Chat
=>
self
.chat_endpoints_enabled
.store
(
enabled
,
Ordering
::
Relaxed
),
EndpointType
::
Completion
=>
self
.cmpl_endpoints_enabled
.store
(
enabled
,
Ordering
::
Relaxed
),
EndpointType
::
Embedding
=>
self
.embeddings_endpoints_enabled
.store
(
enabled
,
Ordering
::
Relaxed
),
EndpointType
::
Responses
=>
self
.responses_endpoints_enabled
.store
(
enabled
,
Ordering
::
Relaxed
),
}
}
}
impl
State
{
...
...
@@ -31,6 +73,12 @@ impl State {
manager
,
metrics
:
Arc
::
new
(
Metrics
::
default
()),
etcd_client
:
None
,
flags
:
StateFlags
{
chat_endpoints_enabled
:
AtomicBool
::
new
(
false
),
cmpl_endpoints_enabled
:
AtomicBool
::
new
(
false
),
embeddings_endpoints_enabled
:
AtomicBool
::
new
(
false
),
responses_endpoints_enabled
:
AtomicBool
::
new
(
false
),
},
}
}
...
...
@@ -39,9 +87,14 @@ impl State {
manager
,
metrics
:
Arc
::
new
(
Metrics
::
default
()),
etcd_client
,
flags
:
StateFlags
{
chat_endpoints_enabled
:
AtomicBool
::
new
(
false
),
cmpl_endpoints_enabled
:
AtomicBool
::
new
(
false
),
embeddings_endpoints_enabled
:
AtomicBool
::
new
(
false
),
responses_endpoints_enabled
:
AtomicBool
::
new
(
false
),
},
}
}
/// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
pub
fn
metrics_clone
(
&
self
)
->
Arc
<
Metrics
>
{
self
.metrics
.clone
()
...
...
@@ -87,10 +140,10 @@ pub struct HttpServiceConfig {
// #[builder(default)]
// custom: Vec<axum::Router>
#[builder(default
=
"
tru
e"
)]
#[builder(default
=
"
fals
e"
)]
enable_chat_endpoints
:
bool
,
#[builder(default
=
"
tru
e"
)]
#[builder(default
=
"
fals
e"
)]
enable_cmpl_endpoints
:
bool
,
#[builder(default
=
"true"
)]
...
...
@@ -151,6 +204,15 @@ impl HttpService {
pub
fn
route_docs
(
&
self
)
->
&
[
RouteDoc
]
{
&
self
.route_docs
}
pub
async
fn
enable_model_endpoint
(
&
self
,
endpoint_type
:
EndpointType
,
enable
:
bool
)
{
self
.state.flags
.set
(
&
endpoint_type
,
enable
);
tracing
::
info!
(
"{} endpoints {}"
,
endpoint_type
.as_str
(),
if
enable
{
"enabled"
}
else
{
"disabled"
}
);
}
}
/// Environment variable to set the metrics endpoint path (default: `/metrics`)
...
...
@@ -177,6 +239,19 @@ impl HttpServiceConfigBuilder {
let
model_manager
=
Arc
::
new
(
ModelManager
::
new
());
let
state
=
Arc
::
new
(
State
::
new_with_etcd
(
model_manager
,
config
.etcd_client
));
state
.flags
.set
(
&
EndpointType
::
Chat
,
config
.enable_chat_endpoints
);
state
.flags
.set
(
&
EndpointType
::
Completion
,
config
.enable_cmpl_endpoints
);
state
.flags
.set
(
&
EndpointType
::
Embedding
,
config
.enable_embeddings_endpoints
);
state
.flags
.set
(
&
EndpointType
::
Responses
,
config
.enable_responses_endpoints
);
// enable prometheus metrics
let
registry
=
metrics
::
Registry
::
new
();
state
.metrics_clone
()
.register
(
&
registry
)
?
;
...
...
@@ -192,42 +267,10 @@ impl HttpServiceConfigBuilder {
super
::
health
::
live_check_router
(
state
.clone
(),
var
(
HTTP_SVC_LIVE_PATH_ENV
)
.ok
()),
];
if
config
.enable_chat_endpoints
{
routes
.push
(
super
::
openai
::
chat_completions_router
(
state
.clone
(),
config
.request_template
.clone
(),
// TODO clone()? reference?
var
(
HTTP_SVC_CHAT_PATH_ENV
)
.ok
(),
));
}
if
config
.enable_cmpl_endpoints
{
routes
.push
(
super
::
openai
::
completions_router
(
state
.clone
(),
var
(
HTTP_SVC_CMP_PATH_ENV
)
.ok
(),
));
}
if
config
.enable_embeddings_endpoints
{
routes
.push
(
super
::
openai
::
embeddings_router
(
state
.clone
(),
var
(
HTTP_SVC_EMB_PATH_ENV
)
.ok
(),
));
}
if
config
.enable_responses_endpoints
{
routes
.push
(
super
::
openai
::
responses_router
(
state
.clone
(),
config
.request_template
,
var
(
HTTP_SVC_RESPONSES_PATH_ENV
)
.ok
(),
));
}
// for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) {
// router = router.merge(route);
// all_docs.extend(route_docs);
// }
for
(
route_docs
,
route
)
in
routes
.into_iter
()
{
let
endpoint_routes
=
HttpServiceConfigBuilder
::
get_endpoints_router
(
state
.clone
(),
&
config
.request_template
);
routes
.extend
(
endpoint_routes
);
for
(
route_docs
,
route
)
in
routes
{
router
=
router
.merge
(
route
);
all_docs
.extend
(
route_docs
);
}
...
...
@@ -253,4 +296,58 @@ impl HttpServiceConfigBuilder {
self
.etcd_client
=
Some
(
etcd_client
);
self
}
fn
get_endpoints_router
(
state
:
Arc
<
State
>
,
request_template
:
&
Option
<
RequestTemplate
>
,
)
->
Vec
<
(
Vec
<
RouteDoc
>
,
axum
::
Router
)
>
{
let
mut
routes
=
Vec
::
new
();
// Add chat completions route with conditional middleware
let
(
chat_docs
,
chat_route
)
=
super
::
openai
::
chat_completions_router
(
state
.clone
(),
request_template
.clone
(),
var
(
HTTP_SVC_CHAT_PATH_ENV
)
.ok
(),
);
let
(
cmpl_docs
,
cmpl_route
)
=
super
::
openai
::
completions_router
(
state
.clone
(),
var
(
HTTP_SVC_CMP_PATH_ENV
)
.ok
());
let
(
embed_docs
,
embed_route
)
=
super
::
openai
::
embeddings_router
(
state
.clone
(),
var
(
HTTP_SVC_EMB_PATH_ENV
)
.ok
());
let
(
responses_docs
,
responses_route
)
=
super
::
openai
::
responses_router
(
state
.clone
(),
request_template
.clone
(),
var
(
HTTP_SVC_RESPONSES_PATH_ENV
)
.ok
(),
);
let
mut
endpoint_routes
=
HashMap
::
new
();
endpoint_routes
.insert
(
EndpointType
::
Chat
,
(
chat_docs
,
chat_route
));
endpoint_routes
.insert
(
EndpointType
::
Completion
,
(
cmpl_docs
,
cmpl_route
));
endpoint_routes
.insert
(
EndpointType
::
Embedding
,
(
embed_docs
,
embed_route
));
endpoint_routes
.insert
(
EndpointType
::
Responses
,
(
responses_docs
,
responses_route
));
for
endpoint_type
in
EndpointType
::
all
()
{
let
state_route
=
state
.clone
();
if
!
endpoint_routes
.contains_key
(
&
endpoint_type
)
{
tracing
::
debug!
(
"{} endpoints are disabled"
,
endpoint_type
.as_str
());
continue
;
}
let
(
docs
,
route
)
=
endpoint_routes
.get
(
&
endpoint_type
)
.cloned
()
.unwrap
();
let
route
=
route
.route_layer
(
axum
::
middleware
::
from_fn
(
move
|
req
:
axum
::
http
::
Request
<
axum
::
body
::
Body
>
,
next
:
axum
::
middleware
::
Next
|
{
let
state
:
Arc
<
State
>
=
state_route
.clone
();
async
move
{
// Check if the endpoint is enabled
let
enabled
=
state
.flags
.get
(
&
endpoint_type
);
if
enabled
{
Ok
(
next
.run
(
req
)
.await
)
}
else
{
tracing
::
debug!
(
"{} endpoints are disabled"
,
endpoint_type
.as_str
());
Err
(
axum
::
http
::
StatusCode
::
SERVICE_UNAVAILABLE
)
}
}
},
));
routes
.push
((
docs
,
route
));
}
routes
}
}
lib/llm/src/lib.rs
View file @
537759f1
...
...
@@ -14,6 +14,7 @@ pub mod backend;
pub
mod
common
;
pub
mod
disagg_router
;
pub
mod
discovery
;
pub
mod
endpoint_type
;
pub
mod
engines
;
pub
mod
entrypoint
;
pub
mod
gguf
;
...
...
lib/llm/src/model_type.rs
View file @
537759f1
...
...
@@ -41,4 +41,13 @@ impl ModelType {
pub
fn
all
()
->
Vec
<
Self
>
{
vec!
[
Self
::
Chat
,
Self
::
Completion
,
Self
::
Embedding
,
Self
::
Backend
]
}
pub
fn
as_endpoint_type
(
&
self
)
->
crate
::
endpoint_type
::
EndpointType
{
match
self
{
Self
::
Chat
=>
crate
::
endpoint_type
::
EndpointType
::
Chat
,
Self
::
Completion
=>
crate
::
endpoint_type
::
EndpointType
::
Completion
,
Self
::
Embedding
=>
crate
::
endpoint_type
::
EndpointType
::
Embedding
,
Self
::
Backend
=>
panic!
(
"Backend model type does not map to an endpoint type"
),
}
}
}
lib/llm/tests/http-service.rs
View file @
537759f1
...
...
@@ -270,7 +270,12 @@ fn inc_counter(
#[allow(deprecated)]
#[tokio::test]
async
fn
test_http_service
()
{
let
service
=
HttpService
::
builder
()
.port
(
8989
)
.build
()
.unwrap
();
let
service
=
HttpService
::
builder
()
.port
(
8989
)
.enable_chat_endpoints
(
true
)
.enable_cmpl_endpoints
(
true
)
.build
()
.unwrap
();
let
state
=
service
.state_clone
();
let
manager
=
state
.manager
();
...
...
@@ -572,7 +577,12 @@ async fn wait_for_service_ready(port: u16) {
fn
service_with_engines
(
#[default(
8990
)]
port
:
u16
,
)
->
(
HttpService
,
Arc
<
CounterEngine
>
,
Arc
<
AlwaysFailEngine
>
)
{
let
service
=
HttpService
::
builder
()
.port
(
port
)
.build
()
.unwrap
();
let
service
=
HttpService
::
builder
()
.enable_chat_endpoints
(
true
)
.enable_cmpl_endpoints
(
true
)
.port
(
port
)
.build
()
.unwrap
();
let
manager
=
service
.model_manager
();
let
counter
=
Arc
::
new
(
CounterEngine
{});
...
...
@@ -958,7 +968,12 @@ async fn test_generic_byot_client(
#[rstest]
#[tokio::test]
async
fn
test_client_disconnect_cancellation_unary
()
{
let
service
=
HttpService
::
builder
()
.port
(
8993
)
.build
()
.unwrap
();
let
service
=
HttpService
::
builder
()
.enable_chat_endpoints
(
true
)
.enable_cmpl_endpoints
(
true
)
.port
(
8993
)
.build
()
.unwrap
();
let
state
=
service
.state_clone
();
let
manager
=
state
.manager
();
...
...
@@ -1044,7 +1059,12 @@ async fn test_client_disconnect_cancellation_unary() {
async
fn
test_client_disconnect_cancellation_streaming
()
{
dynamo_runtime
::
logging
::
init
();
let
service
=
HttpService
::
builder
()
.port
(
8994
)
.build
()
.unwrap
();
let
service
=
HttpService
::
builder
()
.enable_chat_endpoints
(
true
)
.enable_cmpl_endpoints
(
true
)
.port
(
8994
)
.build
()
.unwrap
();
let
state
=
service
.state_clone
();
let
manager
=
state
.manager
();
...
...
@@ -1137,7 +1157,12 @@ async fn test_request_id_annotation() {
// TODO(ryan): make better fixtures, this is too much to test sometime so simple
dynamo_runtime
::
logging
::
init
();
let
service
=
HttpService
::
builder
()
.port
(
8995
)
.build
()
.unwrap
();
let
service
=
HttpService
::
builder
()
.enable_chat_endpoints
(
true
)
.enable_cmpl_endpoints
(
true
)
.port
(
8995
)
.build
()
.unwrap
();
let
state
=
service
.state_clone
();
let
manager
=
state
.manager
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment