Commit 2cca070c authored by Graham King, committed by GitHub

feat(dynamo-run): Batch mode (#142)

```
dynamo-run in=batch:prompts.jsonl out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct/
```

The file is in genai format, one entry per line:
```
{"text": "the prompt"}
{"text": ..etc
```

The prompt is evaluated and the output written to `output.jsonl` in the
same folder as the input.

At the end of the run various statistics are printed:
> Ran 5 files in 8s 679ms. Tokens in: 40 (5/s). Tokens out: 346 (43/s)

This is also helpful for pushing load into the system and stressing the
various components. It is not intended for performance measurement; it is a
batch inference tool.
parent 5cfcfe61
...@@ -1607,6 +1607,7 @@ dependencies = [
"dynamo-runtime",
"futures",
"futures-util",
"humantime",
"netlink-packet-route",
"rtnetlink",
"serde",
...
# Dynamo service runner
`dynamo-run` is a tool for exploring the dynamo components, and an example of how to use them from Rust.
## Setup
...@@ -393,9 +393,30 @@ DYN_TOKEN_ECHO_DELAY_MS=1 dynamo-run in=http out=echo_full
The default delay is 10ms, which produces approximately 100 tokens per second.
## Batch mode
`dynamo-run` can take a JSON Lines (`jsonl`) file of prompts and evaluate them all:
```
dynamo-run in=batch:prompts.jsonl out=llamacpp <model>
```
The input file should look like this:
```
{"text": "What is the capital of France?"}
{"text": "What is the capital of Spain?"}
```
Each entry is passed as a prompt to the model. The output is written to `output.jsonl` in the same folder as the input. At the end of the run some statistics are printed.
The output looks like this:
```
{"text":"What is the capital of France?","response":"The capital of France is Paris.","tokens_in":7,"tokens_out":7,"elapsed_ms":1566}
{"text":"What is the capital of Spain?","response":".The capital of Spain is Madrid.","tokens_in":7,"tokens_out":7,"elapsed_ms":855}
```
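At completion the summary looks something like this (numbers are illustrative):
```
Ran 5 files in 8s 679ms. Tokens in: 40 (5/s). Tokens out: 346 (43/s)
```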
## Defaults
The input defaults to `in=text`.
The output defaults to the `mistralrs` engine. If that is not available, it falls back to whichever engine you compiled in (depending on `--features`).
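For example, assuming a build with the default `mistralrs` engine and a local model checkout (the path below is illustrative), you can rely on both defaults:
```
dynamo-run ~/llm_models/Llama-3.2-3B-Instruct/
```
which is equivalent to `dynamo-run in=text out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct/`.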
...@@ -55,6 +55,7 @@ candle-hf-hub = { version = "0.3.3", default-features = false, features = ["toki
clap = { version = "4.5", features = ["derive", "env"] }
dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] }
futures-util = { version = "0.3" }
humantime = "2.2.0"
tracing-subscriber = { version = "0.3", features = ["env-filter", "local-time", "json"] }
[target.x86_64-unknown-linux-gnu.dependencies]
...
...@@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod batch;
mod common;
pub mod endpoint;
pub mod http;
pub mod text;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use dynamo_llm::model_card::model::ModelDeploymentCard;
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::cmp;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use crate::input::common;
use crate::EngineConfig;
/// Max tokens in each response.
/// TODO: For batch mode this should be the full context size of the model
const MAX_TOKENS: u32 = 8192;
const OUTPUT_FILENAME: &str = "output.jsonl";
const DUMMY_MODEL_NAME: &str = "dynamo-run-batch";
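/// One line of the input/output JSON Lines file. Input lines set only `text`;
/// the remaining fields are filled in before the entry is written to `output.jsonl`.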
#[derive(Serialize, Deserialize, Default, Debug)]
struct Entry {
// The input files only have this
text: String,
response: Option<String>,
#[serde(default)]
tokens_in: usize,
#[serde(default)]
tokens_out: usize,
#[serde(default)]
elapsed_ms: usize,
}
pub async fn run(
runtime: Runtime,
cancel_token: CancellationToken,
maybe_card: Option<ModelDeploymentCard>,
input_jsonl: PathBuf,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
// Check that the path exists and is a file
if !input_jsonl.exists() || !input_jsonl.is_file() {
anyhow::bail!(
"Missing or not a file: {}. Should be a JSON Lines file.",
input_jsonl.display()
);
}
let (_service_name, engine, _inspect_template) =
common::prepare_engine(runtime.clone(), engine_config).await?;
let pre_processor = if let Some(card) = maybe_card {
Some(OpenAIPreprocessor::new(card).await?)
} else {
None
};
let (all_finish_tx, all_finish_rx) = tokio::sync::oneshot::channel();
let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64);
let dw_cancel_token = cancel_token.clone();
let mut output_file = input_jsonl.clone();
output_file.set_file_name(OUTPUT_FILENAME);
tokio::spawn(async move {
if let Err(err) = output_writer(
dw_cancel_token,
done_entries_rx,
&output_file,
all_finish_tx,
)
.await
{
tracing::error!(%err, "Failed writing output to {}", output_file.display());
}
});
let tokens_in = Arc::new(AtomicU64::new(0));
let tokens_out = Arc::new(AtomicU64::new(0));
let mut handles = vec![];
let mut num_entries = 0;
let input_file = tokio::fs::File::open(&input_jsonl)
.await
.with_context(|| input_jsonl.display().to_string())?;
let buffered_input = tokio::io::BufReader::new(input_file);
tracing::info!("Timer start.");
let start = Instant::now();
let mut lines = buffered_input.lines();
while let Ok(Some(line)) = lines.next_line().await {
if cancel_token.is_cancelled() {
break;
}
if line.is_empty() {
continue;
}
let request_id = num_entries;
num_entries += 1;
let mut entry: Entry = match serde_json::from_str(&line) {
Ok(entry) => entry,
Err(err) => {
anyhow::bail!("Error parsing entry: '{line}'. {err}");
}
};
let engine = engine.clone();
let pre_processor = pre_processor.clone();
let tokens_in = tokens_in.clone();
let tokens_out = tokens_out.clone();
let done_entries_tx = done_entries_tx.clone();
let handle = tokio::spawn(async move {
let local_start = Instant::now();
let response = match evaluate(request_id, engine, &entry.text).await {
Ok(r) => r,
Err(err) => {
tracing::error!(%err, entry.text, "Failed evaluating prompt");
return;
}
};
let local_elapsed = Instant::now() - local_start;
entry.elapsed_ms = local_elapsed.as_millis() as usize;
if let Some(pre) = pre_processor {
// Note this does not include the prompt template. Probably TODO
entry.tokens_in = match pre.tokenize(&entry.text) {
Ok(encoding) => encoding.token_ids.len(),
Err(err) => {
tracing::warn!(%err, entry.text, "Failed tokenizing prompt");
0
}
};
entry.tokens_out = match pre.tokenize(&response) {
Ok(encoding) => encoding.token_ids.len(),
Err(err) => {
tracing::warn!(%err, response, "Failed tokenizing response");
0
}
};
tokens_in.fetch_add(entry.tokens_in as u64, Ordering::Relaxed);
tokens_out.fetch_add(entry.tokens_out as u64, Ordering::Relaxed);
}
entry.response = Some(response);
let _ = done_entries_tx.send(entry).await;
});
handles.push(handle);
}
tokio::select! {
_ = cancel_token.cancelled() => {
// Don't print stats
return Ok(());
}
_ = futures::future::join_all(handles) => {
}
_ = all_finish_rx => {
}
}
let elapsed = Instant::now() - start;
let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64);
let tokens_in = Arc::into_inner(tokens_in).unwrap().into_inner();
let tokens_out = Arc::into_inner(tokens_out).unwrap().into_inner();
tokio::time::sleep(Duration::from_millis(1)).await; // Let output_writer finish stdout write
tracing::info!(
"Ran {} files in {}. Tokens in: {} ({}/s). Tokens out: {} ({}/s)",
num_entries,
humantime::format_duration(elapsed_clean),
tokens_in,
tokens_in / cmp::max(elapsed.as_secs(), 1),
tokens_out,
tokens_out / cmp::max(elapsed.as_secs(), 1),
);
Ok(())
}
// Run a single prompt through the engine
async fn evaluate(
_request_id: usize,
engine: OpenAIChatCompletionsStreamingEngine,
prompt: &str,
) -> anyhow::Result<String> {
let user_message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
prompt.to_string(),
),
name: None,
},
);
let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
.messages(vec![user_message])
.model(DUMMY_MODEL_NAME)
.stream(true)
.max_tokens(MAX_TOKENS)
.build()?;
let req = NvCreateChatCompletionRequest { inner, nvext: None };
let mut stream = engine.generate(Context::new(req)).await?;
let mut output = String::new();
while let Some(item) = stream.next().await {
match (item.data.as_ref(), item.event.as_deref()) {
(Some(data), _) => {
// Normal case
let entry = data.inner.choices.first();
let chat_comp = entry.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content {
output += c;
}
if chat_comp.finish_reason.is_some() {
tracing::trace!("finish reason: {:?}", chat_comp.finish_reason.unwrap());
break;
}
}
(None, Some("error")) => {
// There's only one error but we loop in case that changes
for err in item.comment.unwrap_or_default() {
tracing::error!("Engine error: {err}");
}
}
(None, Some(annotation)) => {
tracing::debug!("Annotation. {annotation}: {:?}", item.comment);
}
_ => {
unreachable!("Event from engine with no data, no error, no annotation.");
}
}
}
Ok(output)
}
async fn output_writer(
cancel_token: CancellationToken,
mut entries_rx: tokio::sync::mpsc::Receiver<Entry>,
output_file: &Path,
all_finish_tx: tokio::sync::oneshot::Sender<()>,
) -> anyhow::Result<()> {
let mut num_completed = 0;
let mut f = tokio::fs::File::create(output_file).await?;
loop {
let maybe_entry = tokio::select! {
_ = cancel_token.cancelled() => {
break;
}
entry = entries_rx.recv() => {
entry
}
};
let Some(entry) = maybe_entry else {
let _ = all_finish_tx.send(());
break;
};
let mut s = serde_json::to_string(&entry)?;
s.push('\n');
f.write_all(s.as_bytes()).await?;
num_completed += 1;
// TODO: Progress bar. We'd have to count the lines in the input first,
// and the input may be large.
tracing::info!("Saved {num_completed}");
}
Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::EngineConfig;
use dynamo_llm::{
backend::Backend,
preprocessor::OpenAIPreprocessor,
types::{
openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
OpenAIChatCompletionsStreamingEngine,
},
Annotated,
},
};
use dynamo_runtime::{
pipeline::{ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
DistributedRuntime, Runtime,
};
use std::sync::Arc;
/// Turns an EngineConfig into an OpenAIChatCompletionsStreamingEngine.
pub async fn prepare_engine(
runtime: Runtime,
engine_config: EngineConfig,
) -> anyhow::Result<(String, OpenAIChatCompletionsStreamingEngine, bool)> {
match engine_config {
EngineConfig::Dynamic(endpoint_id) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint = distributed_runtime
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?
.endpoint(endpoint_id.name);
let client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered");
// The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration.
let service_name = endpoint.subject();
Ok((service_name, Arc::new(client), false))
}
EngineConfig::StaticFull {
service_name,
engine,
} => {
tracing::debug!("Model: {service_name}");
Ok((service_name, engine, false))
}
EngineConfig::StaticCore {
service_name,
engine: inner_engine,
card,
} => {
let frontend = ServiceFrontend::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(*card.clone())
.await?
.into_operator();
let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
let engine = ServiceBackend::from_engine(inner_engine);
let pipeline = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(engine)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
tracing::debug!("Model: {service_name} with pre-processing");
Ok((service_name, pipeline, true))
}
EngineConfig::None => unreachable!(),
}
}
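// A sketch of how callers use this helper; batch mode, for example, does the
// equivalent of:
//   let (_service_name, engine, _inspect_template) =
//       common::prepare_engine(runtime.clone(), engine_config).await?;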
...@@ -13,28 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use dynamo_llm::{
backend::Backend,
preprocessor::OpenAIPreprocessor,
types::{
openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
OpenAIChatCompletionsStreamingEngine,
},
Annotated,
},
};
use dynamo_runtime::{
pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
runtime::CancellationToken,
DistributedRuntime, Runtime,
};
use std::{
io::{ErrorKind, Write},
sync::Arc,
};
use dynamo_llm::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime};
use futures::StreamExt;
use std::io::{ErrorKind, Write};
use crate::input::common;
use crate::EngineConfig;
/// Max response tokens for each single query. Must be less than model context size.
...@@ -50,60 +36,7 @@ pub async fn run(
String,
OpenAIChatCompletionsStreamingEngine,
bool,
) = match engine_config {
EngineConfig::Dynamic(endpoint_id) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint = distributed_runtime
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?
.endpoint(endpoint_id.name);
let client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered");
// The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration.
let service_name = endpoint.subject();
(service_name, Arc::new(client), false)
}
EngineConfig::StaticFull {
service_name,
engine,
} => {
tracing::debug!("Model: {service_name}");
(service_name, engine, false)
}
EngineConfig::StaticCore {
service_name,
engine: inner_engine,
card,
} => {
let frontend = ServiceFrontend::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(*card.clone())
.await?
.into_operator();
let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
let engine = ServiceBackend::from_engine(inner_engine);
let pipeline = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(engine)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
tracing::debug!("Model: {service_name} with pre-processing");
(service_name, pipeline, true)
}
EngineConfig::None => unreachable!(),
};
) = common::prepare_engine(runtime.clone(), engine_config).await?;
main_loop(
cancel_token,
&service_name,
...
...@@ -317,7 +317,7 @@ pub async fn run(
if !model_path.is_file() {
anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
}
let Some(card) = maybe_card.clone() else {
anyhow::bail!(
"Pass --model-config so we can find the tokenizer, should be an HF checkout."
);
...@@ -402,6 +402,16 @@ pub async fn run(
)
.await?;
}
Input::Batch(path) => {
crate::input::batch::run(
runtime.clone(),
cancel_token.clone(),
maybe_card,
path,
engine_config,
)
.await?;
}
Input::Endpoint(path) => {
crate::input::endpoint::run(runtime.clone(), path, engine_config).await?;
}
...
...@@ -32,7 +32,7 @@ Example:
const ZMQ_SOCKET_PREFIX: &str = "dyn";
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<file.jsonl>|none] out=[mistralrs|sglang|llamacpp|vllm|trtllm|echo_full|echo_core|pystr:<engine.py>|pytok:<engine.py>] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0]";
fn main() -> anyhow::Result<()> {
logging::init();
...
...@@ -13,10 +13,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{fmt, io::IsTerminal as _, path::PathBuf};
use crate::ENDPOINT_SCHEME;
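// Command-line prefix that selects batch mode, e.g. `in=batch:prompts.jsonl`.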
const BATCH_PREFIX: &str = "batch:";
#[derive(PartialEq)]
pub enum Input {
/// Run an OpenAI compatible HTTP server
...@@ -31,6 +33,9 @@ pub enum Input {
/// Pull requests from a namespace/component/endpoint path.
Endpoint(String),
/// Batch mode. Run all the prompts, write the outputs, exit.
Batch(PathBuf),
/// Start the engine but don't provide any way to talk to it.
/// For multi-node sglang, where the engine connects directly
/// to the co-ordinator via torch distributed / nccl.
...@@ -50,6 +55,10 @@ impl TryFrom<&str> for Input {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
Ok(Input::Endpoint(path.to_string()))
}
batch_path if batch_path.starts_with(BATCH_PREFIX) => {
let path = batch_path.strip_prefix(BATCH_PREFIX).unwrap();
Ok(Input::Batch(PathBuf::from(path)))
}
e => Err(anyhow::anyhow!("Invalid in= option '{e}'")),
}
}
...@@ -62,6 +71,7 @@ impl fmt::Display for Input {
Input::Text => "text",
Input::Stdin => "stdin",
Input::Endpoint(path) => path,
Input::Batch(path) => &path.display().to_string(),
Input::None => "none",
};
write!(f, "{s}")
...
...@@ -34,6 +34,7 @@ use tracing;
use crate::model_card::model::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::tokenizers::Encoding;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
...@@ -88,6 +89,11 @@ impl OpenAIPreprocessor {
}))
}
/// Encode a string to its tokens
pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
self.tokenizer.encode(s)
}
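// Batch mode uses this to count prompt and response tokens, roughly:
//   let tokens_in = pre.tokenize(&entry.text)?.token_ids.len();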
/// Translate a [`NvCreateChatCompletionRequest`] request to a common completion request.
/// Returns both the common completion request and a hashmap of annotations.
///
...