Commit 72064d84 authored by Graham King, committed by GitHub

feat: tio support preprocessor (#265)

Add backend type `EngineConfig::StaticCore` that wraps the engine in a preprocessor (prompt templating and tokenization).

Add an example engine `echo_core` (`out=echo_core`) which takes tokens and returns tokens. A nice side effect is that it echoes the full prompt template, including the system prompt, whereas `echo_full` echoes only the user prompt.

![image](https://github.com/user-attachments/assets/27ec0a7b-a27d-4e69-96ea-1ffa0822ea90)
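For orientation, here is a condensed sketch of the `StaticCore` wiring; the same pipeline is built in each of the `run` functions in the diff below. The `wrap_core_engine` wrapper and the per-edge comments are illustrative additions, not part of the commit:

```rust
use triton_distributed_llm::{
    backend::{Backend, ExecutionContext},
    model_card::model::ModelDeploymentCard,
    preprocessor::OpenAIPreprocessor,
    types::{
        openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
        Annotated,
    },
};
use triton_distributed_runtime::pipeline::{
    ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source,
};

// Condensed from the run() functions in the diff below: wrap a token-level core
// engine with the OpenAI preprocessor and backend on both directions of the pipeline.
async fn wrap_core_engine(
    inner_engine: ExecutionContext,
    card: Box<ModelDeploymentCard>,
) -> anyhow::Result<()> {
    let frontend = ServiceFrontend::<
        SingleIn<ChatCompletionRequest>,
        ManyOut<Annotated<ChatCompletionResponseDelta>>,
    >::new();
    let preprocessor = OpenAIPreprocessor::new(*card.clone()).await?.into_operator();
    let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
    let engine = ServiceBackend::from_engine(inner_engine);

    let _pipeline = frontend
        .link(preprocessor.forward_edge())?  // apply the prompt template and tokenize
        .link(backend.forward_edge())?       // hand tokenized input to the core engine
        .link(engine)?                       // core engine: token ids in, token deltas out
        .link(backend.backward_edge())?      // map engine output back towards the client
        .link(preprocessor.backward_edge())? // detokenize into OpenAI response deltas
        .link(frontend)?;
    Ok(())
}
```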
parent c06b95ff
@@ -14,9 +14,18 @@
// limitations under the License.
use triton_distributed_llm::http::service::discovery::ModelEntry;
use triton_distributed_runtime::{
pipeline::network::Ingress, protocols::Endpoint, DistributedRuntime, Runtime,
use triton_distributed_llm::{
backend::Backend,
preprocessor::OpenAIPreprocessor,
types::{
openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
Annotated,
},
};
use triton_distributed_runtime::pipeline::{
network::Ingress, ManyOut, Operator, SegmentSource, ServiceBackend, SingleIn, Source,
};
use triton_distributed_runtime::{protocols::Endpoint, DistributedRuntime, Runtime};
use crate::{EngineConfig, ENDPOINT_SCHEME};
@@ -28,28 +37,60 @@ pub async fn run(
// This will attempt to connect to NATS and etcd
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
match engine_config {
EngineConfig::StaticFull {
service_name,
engine,
} => {
let cancel_token = runtime.primary_token().clone();
let elements: Vec<&str> = path.split('/').collect();
if elements.len() != 3 {
anyhow::bail!("An endpoint URL must have format {ENDPOINT_SCHEME}namespace/component/endpoint");
anyhow::bail!(
"An endpoint URL must have format {ENDPOINT_SCHEME}namespace/component/endpoint"
);
}
// Register with etcd
let endpoint = Endpoint {
namespace: elements[0].to_string(),
component: elements[1].to_string(),
name: elements[2].to_string(),
};
let etcd_client = distributed.etcd_client();
let (ingress, service_name) = match engine_config {
EngineConfig::StaticFull {
service_name,
engine,
} => (Ingress::for_engine(engine)?, service_name),
EngineConfig::StaticCore {
service_name,
engine: inner_engine,
card,
} => {
let frontend = SegmentSource::<
SingleIn<ChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(*card.clone())
.await?
.into_operator();
let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
let engine = ServiceBackend::from_engine(inner_engine);
let pipeline = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(engine)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
(Ingress::for_pipeline(pipeline)?, service_name)
}
EngineConfig::Dynamic(_) => {
anyhow::bail!("Cannot use endpoint for both in and out");
}
};
let model_registration = ModelEntry {
name: service_name.to_string(),
endpoint,
};
let etcd_client = distributed.etcd_client();
etcd_client
.kv_create(
path.clone(),
@@ -58,8 +99,6 @@ pub async fn run(
)
.await?;
// Start the model
let ingress = Ingress::for_engine(engine)?;
let rt_fut = distributed
.namespace(elements[0])?
.component(elements[1])?
@@ -70,6 +109,7 @@ pub async fn run(
.endpoint_builder()
.handler(ingress)
.start();
tokio::select! {
_ = rt_fut => {
tracing::debug!("Endpoint ingress ended");
@@ -78,9 +118,4 @@ pub async fn run(
}
}
Ok(())
}
EngineConfig::Dynamic(_) => {
anyhow::bail!("Cannot use endpoint for both in and out");
}
}
}
@@ -15,8 +15,19 @@
use std::sync::Arc;
use triton_distributed_llm::http::service::{discovery, service_v2};
use triton_distributed_runtime::{DistributedRuntime, Runtime};
use triton_distributed_llm::{
backend::Backend,
http::service::{discovery, service_v2},
preprocessor::OpenAIPreprocessor,
types::{
openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
Annotated,
},
};
use triton_distributed_runtime::{
pipeline::{ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
DistributedRuntime, Runtime,
};
use crate::EngineConfig;
@@ -55,6 +66,32 @@ pub async fn run(
.model_manager()
.add_chat_completions_model(&service_name, engine)?;
}
EngineConfig::StaticCore {
service_name,
engine: inner_engine,
card,
} => {
let frontend = ServiceFrontend::<
SingleIn<ChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(*card.clone())
.await?
.into_operator();
let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
let engine = ServiceBackend::from_engine(inner_engine);
let pipeline = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(engine)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
http_service
.model_manager()
.add_chat_completions_model(&service_name, pipeline)?;
}
}
http_service.run(runtime.primary_token()).await
}
@@ -19,12 +19,21 @@ use std::{
sync::Arc,
};
use triton_distributed_llm::{
backend::Backend,
preprocessor::OpenAIPreprocessor,
protocols::openai::chat_completions::MessageRole,
types::openai::chat_completions::{
ChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
types::{
openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta,
OpenAIChatCompletionsStreamingEngine,
},
Annotated,
},
};
use triton_distributed_runtime::{pipeline::Context, runtime::CancellationToken};
use triton_distributed_runtime::{
pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
runtime::CancellationToken,
};
use crate::EngineConfig;
@@ -57,6 +66,32 @@ pub async fn run(
tracing::info!("Model: {service_name}");
(service_name, engine, false)
}
EngineConfig::StaticCore {
service_name,
engine: inner_engine,
card,
} => {
let frontend = ServiceFrontend::<
SingleIn<ChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(*card.clone())
.await?
.into_operator();
let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
let engine = ServiceBackend::from_engine(inner_engine);
let pipeline = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(engine)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
tracing::info!("Model: {service_name} with pre-processing");
(service_name, pipeline, true)
}
};
main_loop(cancel_token, &service_name, engine, inspect_template).await
}
@@ -15,11 +15,16 @@
use std::path::PathBuf;
use triton_distributed_llm::types::{
use triton_distributed_llm::{
backend::ExecutionContext,
model_card::model::ModelDeploymentCard,
types::{
openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta, OpenAIChatCompletionsStreamingEngine,
ChatCompletionRequest, ChatCompletionResponseDelta,
OpenAIChatCompletionsStreamingEngine,
},
Annotated,
},
};
use triton_distributed_runtime::{component::Client, DistributedRuntime};
@@ -65,6 +70,13 @@ pub enum EngineConfig {
service_name: String,
engine: OpenAIChatCompletionsStreamingEngine,
},
/// A core engine expects to be wrapped with pre/post processors that handle tokenization.
StaticCore {
service_name: String,
engine: ExecutionContext,
card: Box<ModelDeploymentCard>,
},
}
pub async fn run(
@@ -87,6 +99,15 @@ pub async fn run(
.and_then(|p| p.iter().last())
.map(|n| n.to_string_lossy().into_owned())
});
// If model path is a directory we can build a model deployment card from it
let maybe_card = match &model_path {
Some(model_path) if model_path.is_dir() => {
ModelDeploymentCard::from_local_path(model_path, model_name.as_deref())
.await
.ok()
}
Some(_) | None => None,
};
// Create the engine matching `out`
let engine_config = match out_opt {
@@ -101,6 +122,19 @@
engine: output::echo_full::make_engine_full(),
}
}
Output::EchoCore => {
let Some(mut card) = maybe_card.clone() else {
anyhow::bail!(
"out=echo_core need to find the tokenizer. Pass flag --model-path <path>"
);
};
card.requires_preprocessing = true;
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine: output::echo_core::make_engine_core(),
card: Box::new(card),
}
}
Output::Endpoint(path) => {
let elements: Vec<&str> = path.split('/').collect();
if elements.len() != 3 {
@@ -59,6 +59,9 @@ pub enum Output {
/// Accept un-preprocessed requests, echo the prompt back as the response
EchoFull,
/// Accept preprocessed requests, echo the tokens back as the response
EchoCore,
/// Publish requests to a namespace/component/endpoint path.
Endpoint(String),
@@ -76,6 +79,7 @@ impl TryFrom<&str> for Output {
"mistralrs" => Ok(Output::MistralRs),
"echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore),
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
@@ -94,6 +98,7 @@ impl fmt::Display for Output {
Output::MistralRs => "mistralrs",
Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core",
Output::Endpoint(path) => path,
};
@@ -13,4 +13,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod echo_core;
pub mod echo_full;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{sync::Arc, time::Duration};
use async_stream::stream;
use async_trait::async_trait;
use triton_distributed_llm::backend::ExecutionContext;
use triton_distributed_llm::preprocessor::BackendInput;
use triton_distributed_llm::protocols::common::llm_backend::LLMEngineOutput;
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
/// How long to sleep between echoed tokens.
/// 50ms gives us 20 tok/s.
const TOKEN_ECHO_DELAY: Duration = Duration::from_millis(50);
/// Engine that accepts pre-processed requests and echoes the tokens back as the response.
/// The response will include the full prompt template.
/// Useful for testing pre-processing.
struct EchoEngineCore {}
pub fn make_engine_core() -> ExecutionContext {
Arc::new(EchoEngineCore {})
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for EchoEngineCore
{
async fn generate(
&self,
incoming_request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let (request, context) = incoming_request.into_parts();
let ctx = context.context();
let output = stream! {
for tok in request.token_ids {
tokio::time::sleep(TOKEN_ECHO_DELAY).await;
yield delta_core(tok);
}
yield Annotated::from_data(LLMEngineOutput::stop());
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
let delta = LLMEngineOutput {
token_ids: vec![tok],
tokens: None,
text: None,
cum_log_probs: None,
log_probs: None,
finish_reason: None,
};
Annotated::from_data(delta)
}
@@ -40,7 +40,7 @@ impl ModelDeploymentCard {
/// - Required model files are missing or invalid
pub async fn from_local_path(
local_root_dir: impl AsRef<Path>,
model_name: Option<String>,
model_name: Option<&str>,
) -> anyhow::Result<Self> {
let local_root_dir = local_root_dir.as_ref();
check_valid_local_repo_path(local_root_dir)?;
@@ -53,10 +53,9 @@ impl ModelDeploymentCard {
local_root_dir
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?
.to_string(),
.ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?,
);
Self::from_repo(&repo_id, &model_name).await
Self::from_repo(&repo_id, model_name).await
}
/// TODO: This will be implemented after nova-hub is integrated with the model-card
@@ -47,7 +47,7 @@ async fn make_mdc_from_repo(
//TODO: remove this once we have nim-hub support. See the NOTE above.
let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await;
let display_name = format!("{}--{}", hf_repo, hf_revision);
let mut mdc = ModelDeploymentCard::from_local_path(downloaded_path, Some(display_name))
let mut mdc = ModelDeploymentCard::from_local_path(downloaded_path, Some(&display_name))
.await
.unwrap();
mdc.prompt_context = mixins;