Commit 08fcd7e9 authored by Neelay Shah, committed by GitHub
Browse files

refactor: move libs to lib dir


Signed-off-by: Neelay Shah <neelays@nvidia.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 0bfd9a76
......@@ -16,7 +16,8 @@
import asyncio
from protocol import Request
from triton_distributed_rs import DistributedRuntime, triton_worker
from triton_distributed.runtime import DistributedRuntime, triton_worker
@triton_worker()
......
......@@ -18,7 +18,12 @@ import asyncio
import uvloop
from protocol import Request, Response
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
uvloop.install()
......
......@@ -15,9 +15,9 @@
[project]
name = "triton-distributed-rs"
name = "triton-distributed"
version = "0.2.1"
description = "Distributed LLM Framework"
description = "Distributed Inference Framework"
readme = "README.md"
authors = [
{ name = "NVIDIA Inc.", email = "sw-dl-triton@nvidia.com" },
......@@ -29,13 +29,10 @@ dependencies = [
"uvloop>=0.21.0",
]
# [project.scripts]
# triton-distributed = "triton_distributed_rs:main"
[tool.maturin]
module-name = "triton_distributed_rs._core"
python-packages = ["triton_distributed_rs"]
python-source = "python"
module-name = "triton_distributed._core"
python-packages = ["triton_distributed"]
python-source = "src"
[build-system]
requires = ["maturin>=1.0,<2.0", "patchelf"]
......
......@@ -16,7 +16,7 @@
use std::sync::Arc;
pub use serde::{Deserialize, Serialize};
pub use triton_distributed::{
pub use triton_distributed_runtime::{
error,
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Data, ManyOut, ResponseStream,
......
......@@ -24,14 +24,14 @@ use std::{fmt::Display, sync::Arc};
use tokio::sync::Mutex;
use tracing_subscriber::FmtSubscriber;
use triton_distributed::{
use triton_distributed_runtime::{
self as rs,
pipeline::{EngineStream, ManyOut, SingleIn},
protocols::annotated::Annotated as RsAnnotated,
traits::DistributedRuntimeProvider,
};
use triton_llm::{self as llm_rs};
use triton_distributed_llm::{self as llm_rs};
mod engine;
mod llm;
......@@ -69,8 +69,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
engine::add_to_module(m)?;
// llm::http::add_to_module(m)?;
Ok(())
}
......
......@@ -89,7 +89,7 @@ class Endpoint:
"""
...
async def client() -> Client:
async def client(self) -> Client:
"""
Create a `Client` capable of calling served instances of this endpoint
"""
......@@ -133,7 +133,7 @@ class KvRouter:
...
def __init__(self, drt: DistributedRuntime, component: Component) -> KvRouter:
def __init__(self, drt: DistributedRuntime, component: Component) -> None:
"""
Create a `KvRouter` object that is associated with the `component`
"""
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from triton_distributed._core import KvRouter as KvRouter
......@@ -19,8 +19,8 @@ from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type
from pydantic import BaseModel, ValidationError
from triton_distributed_rs._core import DistributedRuntime
from triton_distributed_rs._core import KvRouter as KvRouter
from triton_distributed._core import DistributedRuntime
def triton_worker():
......@@ -63,18 +63,18 @@ def triton_endpoint(
# Validate the request
try:
if len(args) in [1, 2]:
args = list(args)
args_list = list(args)
if isinstance(args[-1], str):
args[-1] = request_model.parse_raw(args[-1])
args_list[-1] = request_model.parse_raw(args[-1])
elif isinstance(args[-1], dict):
args[-1] = request_model.parse_obj(args[-1])
args_list[-1] = request_model.parse_obj(args[-1])
else:
raise ValueError(f"Invalid request: {args[-1]}")
except ValidationError as e:
raise ValueError(f"Invalid request: {e}")
# Wrap the async generator
async for item in func(*args, **kwargs):
async for item in func(*args_list, **kwargs):
# Validate the response
# TODO: Validate the response
try:
......
......@@ -18,7 +18,8 @@ import random
import string
import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker
from triton_distributed.runtime import DistributedRuntime, triton_worker
# Soak Test
#
......
......@@ -20,7 +20,7 @@ pytestmark = pytest.mark.pre_merge
def test_bindings_install():
# Verify python bindings to rust can be imported
import triton_distributed_rs as tdr
import triton_distributed.runtime as tdr
# Placeholder to avoid unused import errors or removal by linters
assert tdr
......@@ -13,8 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
[workspace.package]
version = "0.2.0"
edition = "2021"
authors = ["NVIDIA"]
license = "Apache-2.0"
homepage = "https://github.com/triton-inference-server/triton_distributed"
repository = "https://github.com/triton-inference-server/triton_distributed"
[package]
name = "triton-llm"
name = "triton-distributed-llm"
version.workspace = true
edition.workspace = true
authors.workspace = true
......@@ -27,10 +35,31 @@ metal = ["mistralrs/metal"]
cuda = ["mistralrs/cuda"]
sentencepiece = ["dep:sentencepiece"]
[workspace.dependencies]
# local or crates.io
triton-distributed-runtime = { version = "0.2.0", path = "../runtime" }
# crates.io
anyhow = { version = "1" }
async-stream = { version = "0.3" }
async-trait = { version = "0.1" }
bytes = "1"
derive_builder = "0.20"
futures = "0.3"
serde = { version = "1", features = ["derive"] }
thiserror = { version = "2.0.11" }
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1" }
tokio-util = { version = "0.7", features = ["codec", "net"] }
tracing = { version = "0.1" }
validator = { version = "0.20.0", features = ["derive"] }
uuid = { version = "1", features = ["v4", "serde"] }
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
[dependencies]
# repo
triton-distributed = { workspace = true }
triton-distributed-runtime = { workspace = true }
# workspace
anyhow = { workspace = true }
......@@ -92,7 +121,6 @@ minijinja = { version = "2.3.1", features = ["loader"] }
minijinja-contrib = { version = "2.3.1", features = ["pycompat"] }
semver = { version = "1", features = ["serde"] }
[dev-dependencies]
proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
......
......@@ -34,7 +34,7 @@ use futures::stream::{self, StreamExt};
use tracing as log;
use crate::model_card::model::{ModelDeploymentCard, TokenizerKind};
use triton_distributed::{
use triton_distributed_runtime::{
pipeline::{
async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment