Commit 08fcd7e9, authored by Neelay Shah, committed by GitHub
Browse files

refactor: move libs to lib dir


Signed-off-by: Neelay Shah <neelays@nvidia.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 0bfd9a76
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
import asyncio import asyncio
from protocol import Request from protocol import Request
from triton_distributed_rs import DistributedRuntime, triton_worker
from triton_distributed.runtime import DistributedRuntime, triton_worker
@triton_worker() @triton_worker()
......
...@@ -18,7 +18,12 @@ import asyncio ...@@ -18,7 +18,12 @@ import asyncio
import uvloop import uvloop
from protocol import Request, Response from protocol import Request, Response
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
uvloop.install() uvloop.install()
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
[project] [project]
name = "triton-distributed-rs" name = "triton-distributed"
version = "0.2.1" version = "0.2.1"
description = "Distributed LLM Framework" description = "Distributed Inference Framework"
readme = "README.md" readme = "README.md"
authors = [ authors = [
{ name = "NVIDIA Inc.", email = "sw-dl-triton@nvidia.com" }, { name = "NVIDIA Inc.", email = "sw-dl-triton@nvidia.com" },
...@@ -29,13 +29,10 @@ dependencies = [ ...@@ -29,13 +29,10 @@ dependencies = [
"uvloop>=0.21.0", "uvloop>=0.21.0",
] ]
# [project.scripts]
# triton-distributed = "triton_distributed_rs:main"
[tool.maturin] [tool.maturin]
module-name = "triton_distributed_rs._core" module-name = "triton_distributed._core"
python-packages = ["triton_distributed_rs"] python-packages = ["triton_distributed"]
python-source = "python" python-source = "src"
[build-system] [build-system]
requires = ["maturin>=1.0,<2.0", "patchelf"] requires = ["maturin>=1.0,<2.0", "patchelf"]
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
use std::sync::Arc; use std::sync::Arc;
pub use serde::{Deserialize, Serialize}; pub use serde::{Deserialize, Serialize};
pub use triton_distributed::{ pub use triton_distributed_runtime::{
error, error,
pipeline::{ pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Data, ManyOut, ResponseStream, async_trait, AsyncEngine, AsyncEngineContextProvider, Data, ManyOut, ResponseStream,
......
...@@ -24,14 +24,14 @@ use std::{fmt::Display, sync::Arc}; ...@@ -24,14 +24,14 @@ use std::{fmt::Display, sync::Arc};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing_subscriber::FmtSubscriber; use tracing_subscriber::FmtSubscriber;
use triton_distributed::{ use triton_distributed_runtime::{
self as rs, self as rs,
pipeline::{EngineStream, ManyOut, SingleIn}, pipeline::{EngineStream, ManyOut, SingleIn},
protocols::annotated::Annotated as RsAnnotated, protocols::annotated::Annotated as RsAnnotated,
traits::DistributedRuntimeProvider, traits::DistributedRuntimeProvider,
}; };
use triton_llm::{self as llm_rs}; use triton_distributed_llm::{self as llm_rs};
mod engine; mod engine;
mod llm; mod llm;
...@@ -69,8 +69,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -69,8 +69,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
engine::add_to_module(m)?; engine::add_to_module(m)?;
// llm::http::add_to_module(m)?;
Ok(()) Ok(())
} }
......
...@@ -89,7 +89,7 @@ class Endpoint: ...@@ -89,7 +89,7 @@ class Endpoint:
""" """
... ...
async def client() -> Client: async def client(self) -> Client:
""" """
Create a `Client` capable of calling served instances of this endpoint Create a `Client` capable of calling served instances of this endpoint
""" """
...@@ -133,7 +133,7 @@ class KvRouter: ...@@ -133,7 +133,7 @@ class KvRouter:
... ...
def __init__(self, drt: DistributedRuntime, component: Component) -> KvRouter: def __init__(self, drt: DistributedRuntime, component: Component) -> None:
""" """
Create a `KvRouter` object that is associated with the `component` Create a `KvRouter` object that is associated with the `component`
""" """
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from triton_distributed._core import KvRouter as KvRouter
...@@ -19,8 +19,8 @@ from functools import wraps ...@@ -19,8 +19,8 @@ from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type from typing import Any, AsyncGenerator, Callable, Type
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from triton_distributed_rs._core import DistributedRuntime
from triton_distributed_rs._core import KvRouter as KvRouter from triton_distributed._core import DistributedRuntime
def triton_worker(): def triton_worker():
...@@ -63,18 +63,18 @@ def triton_endpoint( ...@@ -63,18 +63,18 @@ def triton_endpoint(
# Validate the request # Validate the request
try: try:
if len(args) in [1, 2]: if len(args) in [1, 2]:
args = list(args) args_list = list(args)
if isinstance(args[-1], str): if isinstance(args[-1], str):
args[-1] = request_model.parse_raw(args[-1]) args_list[-1] = request_model.parse_raw(args[-1])
elif isinstance(args[-1], dict): elif isinstance(args[-1], dict):
args[-1] = request_model.parse_obj(args[-1]) args_list[-1] = request_model.parse_obj(args[-1])
else: else:
raise ValueError(f"Invalid request: {args[-1]}") raise ValueError(f"Invalid request: {args[-1]}")
except ValidationError as e: except ValidationError as e:
raise ValueError(f"Invalid request: {e}") raise ValueError(f"Invalid request: {e}")
# Wrap the async generator # Wrap the async generator
async for item in func(*args, **kwargs): async for item in func(*args_list, **kwargs):
# Validate the response # Validate the response
# TODO: Validate the response # TODO: Validate the response
try: try:
......
...@@ -18,7 +18,8 @@ import random ...@@ -18,7 +18,8 @@ import random
import string import string
import uvloop import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker
from triton_distributed.runtime import DistributedRuntime, triton_worker
# Soak Test # Soak Test
# #
......
...@@ -20,7 +20,7 @@ pytestmark = pytest.mark.pre_merge ...@@ -20,7 +20,7 @@ pytestmark = pytest.mark.pre_merge
def test_bindings_install(): def test_bindings_install():
# Verify python bindings to rust can be imported # Verify python bindings to rust can be imported
import triton_distributed_rs as tdr import triton_distributed.runtime as tdr
# Placeholder to avoid unused import errors or removal by linters # Placeholder to avoid unused import errors or removal by linters
assert tdr assert tdr
...@@ -13,8 +13,16 @@ ...@@ -13,8 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
[workspace.package]
version = "0.2.0"
edition = "2021"
authors = ["NVIDIA"]
license = "Apache-2.0"
homepage = "https://github.com/triton-inference-server/triton_distributed"
repository = "https://github.com/triton-inference-server/triton_distributed"
[package] [package]
name = "triton-llm" name = "triton-distributed-llm"
version.workspace = true version.workspace = true
edition.workspace = true edition.workspace = true
authors.workspace = true authors.workspace = true
...@@ -27,10 +35,31 @@ metal = ["mistralrs/metal"] ...@@ -27,10 +35,31 @@ metal = ["mistralrs/metal"]
cuda = ["mistralrs/cuda"] cuda = ["mistralrs/cuda"]
sentencepiece = ["dep:sentencepiece"] sentencepiece = ["dep:sentencepiece"]
[workspace.dependencies]
# local or crates.io
triton-distributed-runtime = { version = "0.2.0", path = "../runtime" }
# crates.io
anyhow = { version = "1" }
async-stream = { version = "0.3" }
async-trait = { version = "0.1" }
bytes = "1"
derive_builder = "0.20"
futures = "0.3"
serde = { version = "1", features = ["derive"] }
thiserror = { version = "2.0.11" }
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1" }
tokio-util = { version = "0.7", features = ["codec", "net"] }
tracing = { version = "0.1" }
validator = { version = "0.20.0", features = ["derive"] }
uuid = { version = "1", features = ["v4", "serde"] }
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
[dependencies] [dependencies]
# repo # repo
triton-distributed = { workspace = true } triton-distributed-runtime = { workspace = true }
# workspace # workspace
anyhow = { workspace = true } anyhow = { workspace = true }
...@@ -92,7 +121,6 @@ minijinja = { version = "2.3.1", features = ["loader"] } ...@@ -92,7 +121,6 @@ minijinja = { version = "2.3.1", features = ["loader"] }
minijinja-contrib = { version = "2.3.1", features = ["pycompat"] } minijinja-contrib = { version = "2.3.1", features = ["pycompat"] }
semver = { version = "1", features = ["serde"] } semver = { version = "1", features = ["serde"] }
[dev-dependencies] [dev-dependencies]
proptest = "1.5.0" proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
......
...@@ -34,7 +34,7 @@ use futures::stream::{self, StreamExt}; ...@@ -34,7 +34,7 @@ use futures::stream::{self, StreamExt};
use tracing as log; use tracing as log;
use crate::model_card::model::{ModelDeploymentCard, TokenizerKind}; use crate::model_card::model::{ModelDeploymentCard, TokenizerKind};
use triton_distributed::{ use triton_distributed_runtime::{
pipeline::{ pipeline::{
async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn, ServerStreamingEngine, SingleIn,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment