Commit c06b95ff authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

ci: Add rust checks to missing directories (#239)


Signed-off-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent 5f1af25a
...@@ -26,9 +26,12 @@ on: ...@@ -26,9 +26,12 @@ on:
branches: branches:
- main - main
paths: paths:
- pre-merge-rust.yml
- 'lib/runtime/**' - 'lib/runtime/**'
- 'lib/llm/**' - 'lib/llm/**'
- 'applications/llm/tio/**' - 'lib/bindings/**'
- 'applications/llm/**'
- 'examples/rust/**'
- '**.rs' - '**.rs'
- 'Cargo.toml' - 'Cargo.toml'
- 'Cargo.lock' - 'Cargo.lock'
...@@ -36,6 +39,8 @@ on: ...@@ -36,6 +39,8 @@ on:
jobs: jobs:
pre-merge-rust: pre-merge-rust:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
matrix: { dir: ['lib/runtime', 'lib/llm', 'lib/bindings/c', 'lib/bindings/python', 'applications/llm/tio', 'applications/llm/count', 'examples/rust'] }
permissions: permissions:
contents: read contents: read
steps: steps:
...@@ -64,27 +69,30 @@ jobs: ...@@ -64,27 +69,30 @@ jobs:
echo "$HOME/.cargo/bin" >> $GITHUB_PATH echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Set up Rust Toolchain Components - name: Set up Rust Toolchain Components
run: rustup component add rustfmt clippy run: rustup component add rustfmt clippy
- name: Run Cargo Check on runtime - name: Run Cargo Check
working-directory: lib/runtime working-directory: ${{ matrix.dir }}
run: cargo check --locked
- name: Run Cargo Check on tio
working-directory: applications/llm/tio
run: cargo check --locked run: cargo check --locked
timeout-minutes: 5
- name: Verify Code Formatting - name: Verify Code Formatting
working-directory: lib/runtime working-directory: ${{ matrix.dir }}
run: cargo fmt -- --check run: cargo fmt -- --check
- name: Run Clippy Checks on runtime - name: Run Clippy Checks
working-directory: lib/runtime working-directory: ${{ matrix.dir }}
run: cargo clippy --no-deps --all-targets -- -D warnings
- name: Run Clippy Checks on tio
working-directory: applications/llm/tio
run: cargo clippy --no-deps --all-targets -- -D warnings run: cargo clippy --no-deps --all-targets -- -D warnings
- name: Install and Run cargo-deny - name: Install and Run cargo-deny
working-directory: lib/runtime working-directory: ${{ matrix.dir }}
# FIXME: Skip this step for failing dirs until license errors fixed
if: |
matrix.dir != 'examples/rust' &&
matrix.dir != 'applications/llm/count' &&
matrix.dir != 'applications/llm/tio' &&
matrix.dir != 'lib/llm' &&
matrix.dir != 'lib/bindings/c' &&
matrix.dir != 'lib/bindings/python'
run: | run: |
cargo-deny --version || cargo install cargo-deny@0.16.4 cargo-deny --version || cargo install cargo-deny@0.16.4
cargo-deny check --hide-inclusion-graph licenses cargo-deny check --hide-inclusion-graph licenses
timeout-minutes: 5 timeout-minutes: 5
- name: Run Unit Tests - name: Run Unit Tests
working-directory: lib/runtime working-directory: ${{ matrix.dir }}
run: cargo test --locked --all-targets run: cargo test --locked --all-targets
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use triton_distributed_llm::http::service::discovery::ModelEntry;
use triton_distributed_runtime::{ use triton_distributed_runtime::{
pipeline::network::Ingress, protocols::Endpoint, DistributedRuntime, Runtime, pipeline::network::Ingress, protocols::Endpoint, DistributedRuntime, Runtime,
}; };
use triton_distributed_llm::http::service::discovery::ModelEntry;
use crate::{EngineConfig, ENDPOINT_SCHEME}; use crate::{EngineConfig, ENDPOINT_SCHEME};
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
use std::sync::Arc; use std::sync::Arc;
use triton_distributed_runtime::{DistributedRuntime, Runtime};
use triton_distributed_llm::http::service::{discovery, service_v2}; use triton_distributed_llm::http::service::{discovery, service_v2};
use triton_distributed_runtime::{DistributedRuntime, Runtime};
use crate::EngineConfig; use crate::EngineConfig;
......
...@@ -18,13 +18,13 @@ use std::{ ...@@ -18,13 +18,13 @@ use std::{
io::{ErrorKind, Read, Write}, io::{ErrorKind, Read, Write},
sync::Arc, sync::Arc,
}; };
use triton_distributed_runtime::{pipeline::Context, runtime::CancellationToken};
use triton_distributed_llm::{ use triton_distributed_llm::{
protocols::openai::chat_completions::MessageRole, protocols::openai::chat_completions::MessageRole,
types::openai::chat_completions::{ types::openai::chat_completions::{
ChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, ChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
}, },
}; };
use triton_distributed_runtime::{pipeline::Context, runtime::CancellationToken};
use crate::EngineConfig; use crate::EngineConfig;
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
use std::path::PathBuf; use std::path::PathBuf;
use triton_distributed_runtime::{component::Client, DistributedRuntime};
use triton_distributed_llm::types::{ use triton_distributed_llm::types::{
openai::chat_completions::{ openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta, OpenAIChatCompletionsStreamingEngine, ChatCompletionRequest, ChatCompletionResponseDelta, OpenAIChatCompletionsStreamingEngine,
}, },
Annotated, Annotated,
}; };
use triton_distributed_runtime::{component::Client, DistributedRuntime};
mod input; mod input;
mod opt; mod opt;
...@@ -138,7 +138,8 @@ pub async fn run( ...@@ -138,7 +138,8 @@ pub async fn run(
}; };
EngineConfig::StaticFull { EngineConfig::StaticFull {
service_name: model_name, service_name: model_name,
engine: triton_distributed_llm::engines::mistralrs::make_engine(&model_path).await?, engine: triton_distributed_llm::engines::mistralrs::make_engine(&model_path)
.await?,
} }
} }
}; };
......
...@@ -18,14 +18,14 @@ use std::{sync::Arc, time::Duration}; ...@@ -18,14 +18,14 @@ use std::{sync::Arc, time::Duration};
use async_stream::stream; use async_stream::stream;
use async_trait::async_trait; use async_trait::async_trait;
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_llm::protocols::openai::chat_completions::FinishReason; use triton_distributed_llm::protocols::openai::chat_completions::FinishReason;
use triton_distributed_llm::protocols::openai::chat_completions::{ use triton_distributed_llm::protocols::openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta, Content, ChatCompletionRequest, ChatCompletionResponseDelta, Content,
}; };
use triton_distributed_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine; use triton_distributed_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
/// How long to sleep between echoed tokens. /// How long to sleep between echoed tokens.
/// 50ms gives us 20 tok/s. /// 50ms gives us 20 tok/s.
......
...@@ -1257,7 +1257,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" ...@@ -1257,7 +1257,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
name = "hello_world" name = "hello_world"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"triton-distributed", "triton-distributed-runtime",
] ]
[[package]] [[package]]
...@@ -1285,8 +1285,8 @@ dependencies = [ ...@@ -1285,8 +1285,8 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"tokio", "tokio",
"triton-distributed", "triton-distributed-llm",
"triton-llm", "triton-distributed-runtime",
] ]
[[package]] [[package]]
...@@ -1732,8 +1732,8 @@ dependencies = [ ...@@ -1732,8 +1732,8 @@ dependencies = [
"tabled", "tabled",
"tokio", "tokio",
"tracing", "tracing",
"triton-distributed", "triton-distributed-llm",
"triton-llm", "triton-distributed-runtime",
] ]
[[package]] [[package]]
...@@ -2897,7 +2897,7 @@ version = "0.2.0" ...@@ -2897,7 +2897,7 @@ version = "0.2.0"
dependencies = [ dependencies = [
"futures", "futures",
"tokio", "tokio",
"triton-distributed", "triton-distributed-runtime",
] ]
[[package]] [[package]]
...@@ -3557,85 +3557,85 @@ dependencies = [ ...@@ -3557,85 +3557,85 @@ dependencies = [
] ]
[[package]] [[package]]
name = "triton-distributed" name = "triton-distributed-llm"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-nats",
"async-once-cell",
"async-stream", "async-stream",
"async-trait", "async-trait",
"async_zmq", "axum 0.8.1",
"blake3", "blake3",
"bs62",
"bytes", "bytes",
"chrono", "chrono",
"derive-getters",
"derive_builder", "derive_builder",
"educe",
"either", "either",
"etcd-client", "erased-serde",
"figment",
"futures", "futures",
"humantime", "galil-seiferas",
"local-ip-address", "indexmap 2.7.1",
"log", "itertools 0.14.0",
"nid", "minijinja",
"nix", "minijinja-contrib",
"nuid",
"once_cell",
"prometheus", "prometheus",
"rand",
"regex", "regex",
"semver",
"serde", "serde",
"serde_json", "serde_json",
"socket2", "thiserror 2.0.11",
"thiserror 1.0.69", "tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"toktrie",
"toktrie_hf_tokenizers",
"tracing", "tracing",
"tracing-subscriber", "triton-distributed-runtime",
"unicode-segmentation",
"uuid", "uuid",
"validator", "validator",
"xxhash-rust", "xxhash-rust",
] ]
[[package]] [[package]]
name = "triton-llm" name = "triton-distributed-runtime"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-nats",
"async-once-cell",
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.8.1", "async_zmq",
"blake3", "blake3",
"bs62",
"bytes", "bytes",
"chrono", "chrono",
"derive-getters",
"derive_builder", "derive_builder",
"educe",
"either", "either",
"erased-serde", "etcd-client",
"figment",
"futures", "futures",
"galil-seiferas", "humantime",
"indexmap 2.7.1", "local-ip-address",
"itertools 0.14.0", "log",
"minijinja", "nid",
"minijinja-contrib", "nix",
"nuid",
"once_cell",
"prometheus", "prometheus",
"rand",
"regex", "regex",
"semver",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.11", "socket2",
"tokenizers", "thiserror 1.0.69",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"toktrie",
"toktrie_hf_tokenizers",
"tracing", "tracing",
"triton-distributed", "tracing-subscriber",
"unicode-segmentation",
"uuid", "uuid",
"validator", "validator",
"xxhash-rust", "xxhash-rust",
......
...@@ -13,15 +13,14 @@ ...@@ -13,15 +13,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Arc;
use clap::Parser; use clap::Parser;
use std::env; use std::sync::Arc;
use triton_distributed_runtime::{logging, DistributedRuntime, Result, Runtime, Worker};
use triton_distributed_llm::http::service::{ use triton_distributed_llm::http::service::{
discovery::{model_watcher, ModelWatchState}, discovery::{model_watcher, ModelWatchState},
service_v2::HttpService, service_v2::HttpService,
}; };
use triton_distributed_runtime::{logging, DistributedRuntime, Result, Runtime, Worker};
#[derive(Parser)] #[derive(Parser)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
...@@ -69,7 +68,9 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -69,7 +68,9 @@ async fn app(runtime: Runtime) -> Result<()> {
// written to etcd // written to etcd
// the cli when operating on an `http` component will validate the namespace.component is // the cli when operating on an `http` component will validate the namespace.component is
// registered with HttpServiceComponentDefinition // registered with HttpServiceComponentDefinition
let component = distributed.namespace(&args.namespace)?.component(&args.component)?; let component = distributed
.namespace(&args.namespace)?
.component(&args.component)?;
let etcd_root = component.etcd_path(); let etcd_root = component.etcd_path();
let etcd_path = format!("{}/models/chat/", etcd_root); let etcd_path = format!("{}/models/chat/", etcd_root);
......
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use tracing as log; use tracing as log;
use triton_distributed_llm::http::service::discovery::ModelEntry;
use triton_distributed_runtime::{ use triton_distributed_runtime::{
distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime, distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime,
Result, Runtime, Worker, Result, Runtime, Worker,
}; };
use triton_distributed_llm::http::service::discovery::ModelEntry;
#[derive(Parser)] #[derive(Parser)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
......
...@@ -17,10 +17,8 @@ use futures::StreamExt; ...@@ -17,10 +17,8 @@ use futures::StreamExt;
use service_metrics::DEFAULT_NAMESPACE; use service_metrics::DEFAULT_NAMESPACE;
use triton_distributed_runtime::{ use triton_distributed_runtime::{
logging, logging, protocols::annotated::Annotated, utils::Duration, DistributedRuntime, Result, Runtime,
protocols::annotated::Annotated, Worker,
utils::{stream, Duration, Instant},
DistributedRuntime, Result, Runtime, Worker,
}; };
fn main() -> Result<()> { fn main() -> Result<()> {
......
...@@ -21,10 +21,10 @@ use std::sync::atomic::{AtomicU32, Ordering}; ...@@ -21,10 +21,10 @@ use std::sync::atomic::{AtomicU32, Ordering};
use tracing as log; use tracing as log;
use uuid::Uuid; use uuid::Uuid;
use triton_distributed_runtime::{DistributedRuntime, Worker};
use triton_distributed_llm::kv_router::{ use triton_distributed_llm::kv_router::{
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvPublisher, indexer::compute_block_hash_for_seq, protocols::*, publisher::KvPublisher,
}; };
use triton_distributed_runtime::{DistributedRuntime, Worker};
static WK: OnceCell<Worker> = OnceCell::new(); static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new(); static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls? // [FIXME] shouldn't the publisher be instance passing between API calls?
......
...@@ -3635,51 +3635,7 @@ dependencies = [ ...@@ -3635,51 +3635,7 @@ dependencies = [
] ]
[[package]] [[package]]
name = "triton-distributed" name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"async_zmq",
"blake3",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"etcd-client",
"figment",
"futures",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"prometheus",
"rand",
"regex",
"serde",
"serde_json",
"socket2",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid",
"validator",
"xxhash-rust",
]
[[package]]
name = "triton-llm"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
...@@ -3712,7 +3668,7 @@ dependencies = [ ...@@ -3712,7 +3668,7 @@ dependencies = [
"toktrie", "toktrie",
"toktrie_hf_tokenizers", "toktrie_hf_tokenizers",
"tracing", "tracing",
"triton-distributed", "triton-distributed-runtime",
"unicode-segmentation", "unicode-segmentation",
"uuid", "uuid",
"validator", "validator",
...@@ -3720,7 +3676,7 @@ dependencies = [ ...@@ -3720,7 +3676,7 @@ dependencies = [
] ]
[[package]] [[package]]
name = "triton_distributed_py3" name = "triton-distributed-py3"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"futures", "futures",
...@@ -3735,8 +3691,52 @@ dependencies = [ ...@@ -3735,8 +3691,52 @@ dependencies = [
"tokio-stream", "tokio-stream",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"triton-distributed", "triton-distributed-llm",
"triton-llm", "triton-distributed-runtime",
]
[[package]]
name = "triton-distributed-runtime"
version = "0.2.0"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"async_zmq",
"blake3",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"etcd-client",
"figment",
"futures",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"prometheus",
"rand",
"regex",
"serde",
"serde_json",
"socket2",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid",
"validator",
"xxhash-rust",
] ]
[[package]] [[package]]
......
...@@ -48,15 +48,13 @@ const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true); ...@@ -48,15 +48,13 @@ const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true);
/// import the module. /// import the module.
#[pymodule] #[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
// Sets up RUST_LOG environment variable for logging through the python-wheel // Sets up RUST_LOG environment variable for logging through the python-wheel
// Example: RUST_LOG=debug python3 -m ... // Example: RUST_LOG=debug python3 -m ...
let subscriber = FmtSubscriber::builder() let subscriber = FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.finish(); .finish();
tracing::subscriber::set_global_default(subscriber) tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
.expect("setting default subscriber failed");
m.add_class::<DistributedRuntime>()?; m.add_class::<DistributedRuntime>()?;
m.add_class::<CancellationToken>()?; m.add_class::<CancellationToken>()?;
......
...@@ -2569,27 +2569,6 @@ dependencies = [ ...@@ -2569,27 +2569,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "libtriton-llm"
version = "0.1.1"
dependencies = [
"anyhow",
"async-once-cell",
"cbindgen",
"futures",
"libc",
"once_cell",
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tracing",
"tracing-subscriber",
"triton-distributed",
"triton-llm",
"uuid 1.13.1",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
...@@ -5427,93 +5406,93 @@ dependencies = [ ...@@ -5427,93 +5406,93 @@ dependencies = [
] ]
[[package]] [[package]]
name = "triton-distributed" name = "triton-distributed-llm"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-nats",
"async-once-cell",
"async-stream", "async-stream",
"async-trait", "async-trait",
"async_zmq", "axum 0.8.1",
"blake3", "blake3",
"bs62",
"bytes", "bytes",
"chrono", "chrono",
"derive-getters",
"derive_builder", "derive_builder",
"educe",
"either", "either",
"etcd-client", "erased-serde",
"figment",
"futures", "futures",
"humantime", "galil-seiferas",
"local-ip-address", "hf-hub 0.4.1",
"log", "indexmap 2.7.1",
"nid", "insta",
"nix", "itertools 0.14.0",
"nuid", "minijinja",
"once_cell", "minijinja-contrib",
"mistralrs",
"prometheus", "prometheus",
"rand", "proptest",
"regex", "regex",
"reqwest",
"rstest",
"semver",
"sentencepiece",
"serde", "serde",
"serde_json", "serde_json",
"socket2", "tempfile",
"thiserror 1.0.69", "thiserror 2.0.11",
"tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"toktrie 0.6.28",
"toktrie_hf_tokenizers 0.6.28",
"tracing", "tracing",
"tracing-subscriber", "triton-distributed-runtime",
"unicode-segmentation",
"uuid 1.13.1", "uuid 1.13.1",
"validator", "validator",
"xxhash-rust", "xxhash-rust",
] ]
[[package]] [[package]]
name = "triton-llm" name = "triton-distributed-runtime"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-nats",
"async-once-cell",
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.8.1", "async_zmq",
"blake3", "blake3",
"bs62",
"bytes", "bytes",
"chrono", "chrono",
"derive-getters",
"derive_builder", "derive_builder",
"educe",
"either", "either",
"erased-serde", "etcd-client",
"figment",
"futures", "futures",
"galil-seiferas", "humantime",
"hf-hub 0.4.1", "local-ip-address",
"indexmap 2.7.1", "log",
"insta", "nid",
"itertools 0.14.0", "nix",
"minijinja", "nuid",
"minijinja-contrib", "once_cell",
"mistralrs",
"prometheus", "prometheus",
"proptest", "rand",
"regex", "regex",
"reqwest",
"rstest",
"semver",
"sentencepiece",
"serde", "serde",
"serde_json", "serde_json",
"tempfile", "socket2",
"thiserror 2.0.11", "thiserror 1.0.69",
"tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"toktrie 0.6.28",
"toktrie_hf_tokenizers 0.6.28",
"tracing", "tracing",
"triton-distributed", "tracing-subscriber",
"unicode-segmentation",
"uuid 1.13.1", "uuid 1.13.1",
"validator", "validator",
"xxhash-rust", "xxhash-rust",
......
...@@ -87,7 +87,7 @@ impl Backend { ...@@ -87,7 +87,7 @@ impl Backend {
let tokenizer = match &mdc.tokenizer { let tokenizer = match &mdc.tokenizer {
TokenizerKind::HfTokenizerJson(file) => { TokenizerKind::HfTokenizerJson(file) => {
HfTokenizer::from_file(&file).map_err(Error::msg)? HfTokenizer::from_file(file).map_err(Error::msg)?
} }
}; };
......
...@@ -72,7 +72,7 @@ impl OpenAIPreprocessor { ...@@ -72,7 +72,7 @@ impl OpenAIPreprocessor {
let PromptFormatter::OAI(formatter) = formatter; let PromptFormatter::OAI(formatter) = formatter;
let tokenizer = match &mdc.tokenizer { let tokenizer = match &mdc.tokenizer {
TokenizerKind::HfTokenizerJson(file) => HuggingFaceTokenizer::from_file(&file)?, TokenizerKind::HfTokenizerJson(file) => HuggingFaceTokenizer::from_file(file)?,
}; };
let tokenizer = Arc::new(tokenizer); let tokenizer = Arc::new(tokenizer);
...@@ -109,7 +109,7 @@ impl OpenAIPreprocessor { ...@@ -109,7 +109,7 @@ impl OpenAIPreprocessor {
let use_raw_prompt = request let use_raw_prompt = request
.nvext() .nvext()
.map_or(false, |ext| ext.use_raw_prompt.unwrap_or(false)); .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));
let formatted_prompt = if use_raw_prompt { let formatted_prompt = if use_raw_prompt {
match request.raw_prompt() { match request.raw_prompt() {
......
...@@ -57,6 +57,7 @@ impl PromptFormatter { ...@@ -57,6 +57,7 @@ impl PromptFormatter {
/// 2. Map template: Contains 'tool_use' and/or 'default' templates /// 2. Map template: Contains 'tool_use' and/or 'default' templates
/// - tool_use: Template for tool-based interactions /// - tool_use: Template for tool-based interactions
/// - default: Template for standard chat interactions /// - default: Template for standard chat interactions
///
/// If the map contains both keys, the `tool_use` template is registered as the `tool_use` template /// If the map contains both keys, the `tool_use` template is registered as the `tool_use` template
/// and the `default` template is registered as the `default` template. /// and the `default` template is registered as the `default` template.
struct JinjaEnvironment { struct JinjaEnvironment {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#### HuggingFace Tokenizer #### HuggingFace Tokenizer
```rust ```rust
use triton_llm::tokenizers::hf::HuggingFaceTokenizer; use triton_distributed_llm::tokenizers::hf::HuggingFaceTokenizer;
let hf_tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.json") let hf_tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.json")
.expect("Failed to load HuggingFace tokenizer"); .expect("Failed to load HuggingFace tokenizer");
...@@ -22,7 +22,7 @@ let hf_tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/Tin ...@@ -22,7 +22,7 @@ let hf_tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/Tin
### Encoding and Decoding Text ### Encoding and Decoding Text
```rust ```rust
use triton_llm::tokenizers::{HuggingFaceTokenizer, traits::{Encoder, Decoder}}; use triton_distributed_llm::tokenizers::{HuggingFaceTokenizer, traits::{Encoder, Decoder}};
let tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.json") let tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.json")
.expect("Failed to load HuggingFace tokenizer"); .expect("Failed to load HuggingFace tokenizer");
...@@ -40,7 +40,7 @@ assert_eq!(text, decoded_text); ...@@ -40,7 +40,7 @@ assert_eq!(text, decoded_text);
// Using the Sequence object for encoding and decoding // Using the Sequence object for encoding and decoding
use triton_llm::tokenizers::{Sequence, Tokenizer}; use triton_distributed_llm::tokenizers::{Sequence, Tokenizer};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
let tokenizer = Tokenizer::from(Arc::new(tokenizer)); let tokenizer = Tokenizer::from(Arc::new(tokenizer));
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use triton_llm::backend::Backend; use triton_distributed_llm::backend::Backend;
use triton_llm::model_card::model::ModelDeploymentCard; use triton_distributed_llm::model_card::model::ModelDeploymentCard;
#[tokio::test] #[tokio::test]
async fn test_sequence_factory() { async fn test_sequence_factory() {
......
...@@ -18,12 +18,6 @@ use async_stream::stream; ...@@ -18,12 +18,6 @@ use async_stream::stream;
use prometheus::{proto::MetricType, Registry}; use prometheus::{proto::MetricType, Registry};
use reqwest::StatusCode; use reqwest::StatusCode;
use std::sync::Arc; use std::sync::Arc;
use triton_distributed_runtime::{
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn,
},
CancellationToken,
};
use triton_distributed_llm::http::service::{ use triton_distributed_llm::http::service::{
error::HttpError, error::HttpError,
metrics::{Endpoint, RequestType, Status}, metrics::{Endpoint, RequestType, Status},
...@@ -37,6 +31,12 @@ use triton_distributed_llm::protocols::{ ...@@ -37,6 +31,12 @@ use triton_distributed_llm::protocols::{
}, },
Annotated, Annotated,
}; };
use triton_distributed_runtime::{
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn,
},
CancellationToken,
};
struct CounterEngine {} struct CounterEngine {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment