Commit c06b95ff authored by Ryan McCormick, committed by GitHub

ci: Add rust checks to missing directories (#239)


Signed-off-by: Ryan McCormick <rmccormick@nvidia.com>
parent 5f1af25a
@@ -26,9 +26,12 @@ on:
branches:
- main
paths:
- pre-merge-rust.yml
- 'lib/runtime/**'
- 'lib/llm/**'
- 'applications/llm/tio/**'
- 'lib/bindings/**'
- 'applications/llm/**'
- 'examples/rust/**'
- '**.rs'
- 'Cargo.toml'
- 'Cargo.lock'
@@ -36,6 +39,8 @@ on:
jobs:
pre-merge-rust:
runs-on: ubuntu-latest
strategy:
matrix: { dir: ['lib/runtime', 'lib/llm', 'lib/bindings/c', 'lib/bindings/python', 'applications/llm/tio', 'applications/llm/count', 'examples/rust'] }
permissions:
contents: read
steps:
@@ -64,27 +69,30 @@ jobs:
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Set up Rust Toolchain Components
run: rustup component add rustfmt clippy
- name: Run Cargo Check on runtime
working-directory: lib/runtime
run: cargo check --locked
- name: Run Cargo Check on tio
working-directory: applications/llm/tio
- name: Run Cargo Check
working-directory: ${{ matrix.dir }}
run: cargo check --locked
timeout-minutes: 5
- name: Verify Code Formatting
working-directory: lib/runtime
working-directory: ${{ matrix.dir }}
run: cargo fmt -- --check
- name: Run Clippy Checks on runtime
working-directory: lib/runtime
run: cargo clippy --no-deps --all-targets -- -D warnings
- name: Run Clippy Checks on tio
working-directory: applications/llm/tio
- name: Run Clippy Checks
working-directory: ${{ matrix.dir }}
run: cargo clippy --no-deps --all-targets -- -D warnings
- name: Install and Run cargo-deny
working-directory: lib/runtime
working-directory: ${{ matrix.dir }}
# FIXME: Skip this step for failing dirs until license errors fixed
if: |
matrix.dir != 'examples/rust' &&
matrix.dir != 'applications/llm/count' &&
matrix.dir != 'applications/llm/tio' &&
matrix.dir != 'lib/llm' &&
matrix.dir != 'lib/bindings/c' &&
matrix.dir != 'lib/bindings/python'
run: |
cargo-deny --version || cargo install cargo-deny@0.16.4
cargo-deny check --hide-inclusion-graph licenses
timeout-minutes: 5
- name: Run Unit Tests
working-directory: lib/runtime
working-directory: ${{ matrix.dir }}
run: cargo test --locked --all-targets
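
For anyone reproducing the CI locally, the matrixed job boils down to the same cargo invocations in each directory. A rough sketch — the directory list and commands are copied from the workflow above, the loop structure is mine, and cargo-deny currently runs only for `lib/runtime` per the FIXME:

```bash
# Rough local equivalent of the matrixed job above.
for dir in lib/runtime lib/llm lib/bindings/c lib/bindings/python \
           applications/llm/tio applications/llm/count examples/rust; do
  (
    cd "$dir" || exit 1
    cargo check --locked
    cargo fmt -- --check
    cargo clippy --no-deps --all-targets -- -D warnings
    cargo test --locked --all-targets
  )
done

# License checks currently run only where they pass (see FIXME in the diff):
(cd lib/runtime && cargo-deny check --hide-inclusion-graph licenses)
```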
@@ -13,10 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use triton_distributed_llm::http::service::discovery::ModelEntry;
use triton_distributed_runtime::{
pipeline::network::Ingress, protocols::Endpoint, DistributedRuntime, Runtime,
};
use triton_distributed_llm::http::service::discovery::ModelEntry;
use crate::{EngineConfig, ENDPOINT_SCHEME};
@@ -15,8 +15,8 @@
use std::sync::Arc;
use triton_distributed_runtime::{DistributedRuntime, Runtime};
use triton_distributed_llm::http::service::{discovery, service_v2};
use triton_distributed_runtime::{DistributedRuntime, Runtime};
use crate::EngineConfig;
@@ -18,13 +18,13 @@ use std::{
io::{ErrorKind, Read, Write},
sync::Arc,
};
use triton_distributed_runtime::{pipeline::Context, runtime::CancellationToken};
use triton_distributed_llm::{
protocols::openai::chat_completions::MessageRole,
types::openai::chat_completions::{
ChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
},
};
use triton_distributed_runtime::{pipeline::Context, runtime::CancellationToken};
use crate::EngineConfig;
@@ -15,13 +15,13 @@
use std::path::PathBuf;
use triton_distributed_runtime::{component::Client, DistributedRuntime};
use triton_distributed_llm::types::{
openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta, OpenAIChatCompletionsStreamingEngine,
},
Annotated,
};
use triton_distributed_runtime::{component::Client, DistributedRuntime};
mod input;
mod opt;
@@ -138,7 +138,8 @@ pub async fn run(
};
EngineConfig::StaticFull {
service_name: model_name,
engine: triton_distributed_llm::engines::mistralrs::make_engine(&model_path).await?,
engine: triton_distributed_llm::engines::mistralrs::make_engine(&model_path)
.await?,
}
}
};
@@ -18,14 +18,14 @@ use std::{sync::Arc, time::Duration};
use async_stream::stream;
use async_trait::async_trait;
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_llm::protocols::openai::chat_completions::FinishReason;
use triton_distributed_llm::protocols::openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta, Content,
};
use triton_distributed_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
/// How long to sleep between echoed tokens.
/// 50ms gives us 20 tok/s.
@@ -1257,7 +1257,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
name = "hello_world"
version = "0.2.0"
dependencies = [
"triton-distributed",
"triton-distributed-runtime",
]
[[package]]
@@ -1285,8 +1285,8 @@ dependencies = [
"serde",
"serde_json",
"tokio",
"triton-distributed",
"triton-llm",
"triton-distributed-llm",
"triton-distributed-runtime",
]
[[package]]
@@ -1732,8 +1732,8 @@ dependencies = [
"tabled",
"tokio",
"tracing",
"triton-distributed",
"triton-llm",
"triton-distributed-llm",
"triton-distributed-runtime",
]
[[package]]
@@ -2897,7 +2897,7 @@ version = "0.2.0"
dependencies = [
"futures",
"tokio",
"triton-distributed",
"triton-distributed-runtime",
]
[[package]]
@@ -3557,85 +3557,85 @@ dependencies = [
]
[[package]]
name = "triton-distributed"
name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"async_zmq",
"axum 0.8.1",
"blake3",
"bs62",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"etcd-client",
"figment",
"erased-serde",
"futures",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"galil-seiferas",
"indexmap 2.7.1",
"itertools 0.14.0",
"minijinja",
"minijinja-contrib",
"prometheus",
"rand",
"regex",
"semver",
"serde",
"serde_json",
"socket2",
"thiserror 1.0.69",
"thiserror 2.0.11",
"tokenizers",
"tokio",
"tokio-stream",
"tokio-util",
"toktrie",
"toktrie_hf_tokenizers",
"tracing",
"tracing-subscriber",
"triton-distributed-runtime",
"unicode-segmentation",
"uuid",
"validator",
"xxhash-rust",
]
[[package]]
name = "triton-llm"
name = "triton-distributed-runtime"
version = "0.2.0"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"axum 0.8.1",
"async_zmq",
"blake3",
"bs62",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"erased-serde",
"etcd-client",
"figment",
"futures",
"galil-seiferas",
"indexmap 2.7.1",
"itertools 0.14.0",
"minijinja",
"minijinja-contrib",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"prometheus",
"rand",
"regex",
"semver",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokenizers",
"socket2",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"toktrie",
"toktrie_hf_tokenizers",
"tracing",
"triton-distributed",
"unicode-segmentation",
"tracing-subscriber",
"uuid",
"validator",
"xxhash-rust",
@@ -13,15 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use clap::Parser;
use std::env;
use std::sync::Arc;
use triton_distributed_runtime::{logging, DistributedRuntime, Result, Runtime, Worker};
use triton_distributed_llm::http::service::{
discovery::{model_watcher, ModelWatchState},
service_v2::HttpService,
};
use triton_distributed_runtime::{logging, DistributedRuntime, Result, Runtime, Worker};
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
@@ -69,7 +68,9 @@ async fn app(runtime: Runtime) -> Result<()> {
// written to etcd
// the cli when operating on an `http` component will validate the namespace.component is
// registered with HttpServiceComponentDefinition
let component = distributed.namespace(&args.namespace)?.component(&args.component)?;
let component = distributed
.namespace(&args.namespace)?
.component(&args.component)?;
let etcd_root = component.etcd_path();
let etcd_path = format!("{}/models/chat/", etcd_root);
@@ -16,11 +16,11 @@
use clap::{Parser, Subcommand};
use tracing as log;
use triton_distributed_llm::http::service::discovery::ModelEntry;
use triton_distributed_runtime::{
distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime,
Result, Runtime, Worker,
};
use triton_distributed_llm::http::service::discovery::ModelEntry;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
@@ -17,10 +17,8 @@ use futures::StreamExt;
use service_metrics::DEFAULT_NAMESPACE;
use triton_distributed_runtime::{
logging,
protocols::annotated::Annotated,
utils::{stream, Duration, Instant},
DistributedRuntime, Result, Runtime, Worker,
logging, protocols::annotated::Annotated, utils::Duration, DistributedRuntime, Result, Runtime,
Worker,
};
fn main() -> Result<()> {
@@ -21,10 +21,10 @@ use std::sync::atomic::{AtomicU32, Ordering};
use tracing as log;
use uuid::Uuid;
use triton_distributed_runtime::{DistributedRuntime, Worker};
use triton_distributed_llm::kv_router::{
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvPublisher,
};
use triton_distributed_runtime::{DistributedRuntime, Worker};
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
@@ -3635,51 +3635,7 @@ dependencies = [
]
[[package]]
name = "triton-distributed"
version = "0.2.0"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"async_zmq",
"blake3",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"etcd-client",
"figment",
"futures",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"prometheus",
"rand",
"regex",
"serde",
"serde_json",
"socket2",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid",
"validator",
"xxhash-rust",
]
[[package]]
name = "triton-llm"
name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
@@ -3712,7 +3668,7 @@ dependencies = [
"toktrie",
"toktrie_hf_tokenizers",
"tracing",
"triton-distributed",
"triton-distributed-runtime",
"unicode-segmentation",
"uuid",
"validator",
@@ -3720,7 +3676,7 @@ dependencies = [
]
[[package]]
name = "triton_distributed_py3"
name = "triton-distributed-py3"
version = "0.2.0"
dependencies = [
"futures",
@@ -3735,8 +3691,52 @@ dependencies = [
"tokio-stream",
"tracing",
"tracing-subscriber",
"triton-distributed",
"triton-llm",
"triton-distributed-llm",
"triton-distributed-runtime",
]
[[package]]
name = "triton-distributed-runtime"
version = "0.2.0"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"async_zmq",
"blake3",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"etcd-client",
"figment",
"futures",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"prometheus",
"rand",
"regex",
"serde",
"serde_json",
"socket2",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid",
"validator",
"xxhash-rust",
]
[[package]]
@@ -48,15 +48,13 @@ const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true);
/// import the module.
#[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
// Sets up RUST_LOG environment variable for logging through the python-wheel
// Example: RUST_LOG=debug python3 -m ...
let subscriber = FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.finish();
tracing::subscriber::set_global_default(subscriber)
.expect("setting default subscriber failed");
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
m.add_class::<DistributedRuntime>()?;
m.add_class::<CancellationToken>()?;
@@ -2569,27 +2569,6 @@ dependencies = [
"libc",
]
[[package]]
name = "libtriton-llm"
version = "0.1.1"
dependencies = [
"anyhow",
"async-once-cell",
"cbindgen",
"futures",
"libc",
"once_cell",
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tracing",
"tracing-subscriber",
"triton-distributed",
"triton-llm",
"uuid 1.13.1",
]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
@@ -5427,93 +5406,93 @@ dependencies = [
]
[[package]]
name = "triton-distributed"
name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"async_zmq",
"axum 0.8.1",
"blake3",
"bs62",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"etcd-client",
"figment",
"erased-serde",
"futures",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"galil-seiferas",
"hf-hub 0.4.1",
"indexmap 2.7.1",
"insta",
"itertools 0.14.0",
"minijinja",
"minijinja-contrib",
"mistralrs",
"prometheus",
"rand",
"proptest",
"regex",
"reqwest",
"rstest",
"semver",
"sentencepiece",
"serde",
"serde_json",
"socket2",
"thiserror 1.0.69",
"tempfile",
"thiserror 2.0.11",
"tokenizers",
"tokio",
"tokio-stream",
"tokio-util",
"toktrie 0.6.28",
"toktrie_hf_tokenizers 0.6.28",
"tracing",
"tracing-subscriber",
"triton-distributed-runtime",
"unicode-segmentation",
"uuid 1.13.1",
"validator",
"xxhash-rust",
]
[[package]]
name = "triton-llm"
name = "triton-distributed-runtime"
version = "0.2.0"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"axum 0.8.1",
"async_zmq",
"blake3",
"bs62",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"erased-serde",
"etcd-client",
"figment",
"futures",
"galil-seiferas",
"hf-hub 0.4.1",
"indexmap 2.7.1",
"insta",
"itertools 0.14.0",
"minijinja",
"minijinja-contrib",
"mistralrs",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"prometheus",
"proptest",
"rand",
"regex",
"reqwest",
"rstest",
"semver",
"sentencepiece",
"serde",
"serde_json",
"tempfile",
"thiserror 2.0.11",
"tokenizers",
"socket2",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"toktrie 0.6.28",
"toktrie_hf_tokenizers 0.6.28",
"tracing",
"triton-distributed",
"unicode-segmentation",
"tracing-subscriber",
"uuid 1.13.1",
"validator",
"xxhash-rust",
@@ -87,7 +87,7 @@ impl Backend {
let tokenizer = match &mdc.tokenizer {
TokenizerKind::HfTokenizerJson(file) => {
HfTokenizer::from_file(&file).map_err(Error::msg)?
HfTokenizer::from_file(file).map_err(Error::msg)?
}
};
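
The `&file` → `file` changes in this and the next hunk are clippy needless-borrow fixes: the `match` on `&mdc.tokenizer` already binds `file` as a reference, so borrowing it again only adds a `&&String` that the compiler peels off anyway. A self-contained sketch of the pattern — the enum and `load` helper here are stand-ins, not the library API:

```rust
use std::path::Path;

enum TokenizerKind {
    HfTokenizerJson(String),
}

// Stand-in for a generic file loader such as `from_file`.
fn load(path: impl AsRef<Path>) -> bool {
    path.as_ref().exists()
}

fn demo(kind: &TokenizerKind) {
    // Matching on a reference binds `file` as `&String`.
    let TokenizerKind::HfTokenizerJson(file) = kind;
    let _ = load(&file); // extra borrow: passes `&&String`; clippy warns
    let _ = load(file); // preferred: `file` is already a reference
}

fn main() {
    demo(&TokenizerKind::HfTokenizerJson("tokenizer.json".into()));
}
```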
@@ -72,7 +72,7 @@ impl OpenAIPreprocessor {
let PromptFormatter::OAI(formatter) = formatter;
let tokenizer = match &mdc.tokenizer {
TokenizerKind::HfTokenizerJson(file) => HuggingFaceTokenizer::from_file(&file)?,
TokenizerKind::HfTokenizerJson(file) => HuggingFaceTokenizer::from_file(file)?,
};
let tokenizer = Arc::new(tokenizer);
@@ -109,7 +109,7 @@ impl OpenAIPreprocessor {
let use_raw_prompt = request
.nvext()
.map_or(false, |ext| ext.use_raw_prompt.unwrap_or(false));
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));
let formatted_prompt = if use_raw_prompt {
match request.raw_prompt() {
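
The second hunk swaps `map_or(false, …)` for `is_some_and(…)`, the clippy-preferred spelling of "is `Some` and the value satisfies this predicate" (stable since Rust 1.70). A minimal equivalence sketch, with a hypothetical `use_raw` flag standing in for the request extension:

```rust
fn main() {
    let use_raw: Option<bool> = Some(true);

    // Old form: default to false when None, otherwise apply the closure.
    let old = use_raw.map_or(false, |v| v);
    // New form: reads directly as "is Some and the value holds".
    let new = use_raw.is_some_and(|v| v);

    assert_eq!(old, new);
}
```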
@@ -57,6 +57,7 @@ impl PromptFormatter {
/// 2. Map template: Contains 'tool_use' and/or 'default' templates
/// - tool_use: Template for tool-based interactions
/// - default: Template for standard chat interactions
///
/// If the map contains both keys, the `tool_use` template is registered as the `tool_use` template
/// and the `default` template is registered as the `default` template.
struct JinjaEnvironment {
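
The blank `///` line added here closes the Markdown list so the sentence after it starts a new paragraph; newer clippy releases warn when a doc paragraph lazily continues the last list item. A small illustrative sketch:

```rust
/// Accepted inputs:
/// 1. A single template string
/// 2. A map with `tool_use` and/or `default` templates
///
/// Without the empty `///` line above, rustdoc would fold this paragraph
/// into list item 2 (clippy warns about such lazy continuations).
pub struct Example;

fn main() {
    let _ = Example;
}
```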
@@ -13,7 +13,7 @@
#### HuggingFace Tokenizer
```rust
use triton_llm::tokenizers::hf::HuggingFaceTokenizer;
use triton_distributed_llm::tokenizers::hf::HuggingFaceTokenizer;
let hf_tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.json")
.expect("Failed to load HuggingFace tokenizer");
@@ -22,7 +22,7 @@ let hf_tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/Tin
### Encoding and Decoding Text
```rust
use triton_llm::tokenizers::{HuggingFaceTokenizer, traits::{Encoder, Decoder}};
use triton_distributed_llm::tokenizers::{HuggingFaceTokenizer, traits::{Encoder, Decoder}};
let tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.json")
.expect("Failed to load HuggingFace tokenizer");
@@ -40,7 +40,7 @@ assert_eq!(text, decoded_text);
// Using the Sequence object for encoding and decoding
use triton_llm::tokenizers::{Sequence, Tokenizer};
use triton_distributed_llm::tokenizers::{Sequence, Tokenizer};
use std::sync::{Arc, RwLock};
let tokenizer = Tokenizer::from(Arc::new(tokenizer));
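
Putting the renamed imports together, a minimal sketch that combines the two README snippets above into one program — only constructors shown in this README are used, and the exact module paths should be treated as illustrative:

```rust
use std::sync::Arc;
use triton_distributed_llm::tokenizers::hf::HuggingFaceTokenizer;
use triton_distributed_llm::tokenizers::Tokenizer;

fn main() {
    // Load the tokenizer definition, as in the first snippet.
    let hf_tokenizer = HuggingFaceTokenizer::from_file(
        "tests/data/sample-models/TinyLlama_v1.1/tokenizer.json",
    )
    .expect("Failed to load HuggingFace tokenizer");

    // Wrap it in the shared `Tokenizer` handle, as in the `Sequence` snippet.
    let _tokenizer = Tokenizer::from(Arc::new(hf_tokenizer));
}
```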
@@ -13,8 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use triton_llm::backend::Backend;
use triton_llm::model_card::model::ModelDeploymentCard;
use triton_distributed_llm::backend::Backend;
use triton_distributed_llm::model_card::model::ModelDeploymentCard;
#[tokio::test]
async fn test_sequence_factory() {
@@ -18,12 +18,6 @@ use async_stream::stream;
use prometheus::{proto::MetricType, Registry};
use reqwest::StatusCode;
use std::sync::Arc;
use triton_distributed_runtime::{
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn,
},
CancellationToken,
};
use triton_distributed_llm::http::service::{
error::HttpError,
metrics::{Endpoint, RequestType, Status},
@@ -37,6 +31,12 @@ use triton_distributed_llm::protocols::{
},
Annotated,
};
use triton_distributed_runtime::{
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn,
},
CancellationToken,
};
struct CounterEngine {}