"container/vscode:/vscode.git/clone" did not exist on "c3b847901099bf5c3dd174a3c8ec994b73426833"
Commit a32cdad6 authored by Biswa Panda, committed by GitHub

feat: add python binding for rust llm modules (#13)

parent 437d8e37
......@@ -40,6 +40,7 @@ CMakeCache.txt
**/__pycache__
**/venv
**/.venv
*.cache
# Examples
......@@ -73,3 +74,5 @@ __pycache__/
*.py[cod]
*$py.class
*.so
**/.devcontainer
\ No newline at end of file
......@@ -342,8 +342,59 @@ curl localhost:8080/v1/chat/completions -H "Content-Type: application/json"
"system_fingerprint": null
}
```
### 6. Preprocessor and Backend
This deployment splits pre-processing and the model-serving backend into separate components.
Run the following commands in four separate terminals:
**Terminal 1 - vLLM Worker:**
```bash
# Activate virtual environment
source /opt/triton/venv/bin/activate
cd /workspace/examples/python_rs/llm/vllm
RUST_LOG=info python3 -m preprocessor.worker --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
```
**Terminal 2 - Preprocessor:**
```bash
# Activate virtual environment
source /opt/triton/venv/bin/activate
cd /workspace/examples/python_rs/llm/vllm
RUST_LOG=info python3 -m preprocessor.processor --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
```
**Terminal 3 - HTTP Server:**
Run the HTTP server (with debug-level logging):
```bash
TRD_LOG=DEBUG http
```
By default, the server runs on port 8080.
Add the model to the server:
```bash
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B triton-init.preprocessor.generate
```
**Terminal 4 - Client:**
```bash
curl localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"messages": [
{"role": "user", "content": "What is the capital of France?"}
]
}'
```
### 7. Known Issues and Limitations
- vLLM does not work well with the `fork` multiprocessing method when TP > 1. This is a known issue; the workaround is to use the `spawn` method instead (a minimal sketch follows this list). See [vLLM issue](https://github.com/vllm-project/vllm/issues/6152).
- The `kv_rank` of the `kv_producer` must be smaller than that of the `kv_consumer`.
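The snippet below is an illustrative sketch, not part of this change. It assumes vLLM's `VLLM_WORKER_MULTIPROC_METHOD` environment variable selects the multiprocessing start method and that it is set before the engine is created (for example, at the top of the worker module):
```python
import os

# Assumption: VLLM_WORKER_MULTIPROC_METHOD is the vLLM setting that picks the
# multiprocessing start method; "spawn" avoids the fork + TP > 1 issue above.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
```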
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from huggingface_hub import snapshot_download
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
@dataclass
class NvAsyncEngineArgs(AsyncEngineArgs):
model_path: str = field(default="")
def parse_vllm_args() -> NvAsyncEngineArgs:
parser = FlexibleArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument(
"--model-path",
type=str,
default="",
)
args = parser.parse_args()
if args.model_path == "":
if os.environ.get("HF_TOKEN"):
args.model_path = snapshot_download(args.model)
else:
raise ValueError(
"Please set HF_TOKEN environment variable "
"or pass --model-path to load the model"
)
return NvAsyncEngineArgs.from_cli_args(args)
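# Illustrative invocations (a sketch based on this file and the README above):
#   python3 -m preprocessor.worker --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
#       requires HF_TOKEN so the weights can be fetched via snapshot_download()
#   python3 -m preprocessor.worker --model my-model --model-path /path/to/weights
#       uses a local copy of the weights and skips the Hugging Face download
#       (the "my-model" name and local path here are hypothetical)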
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PORT=8080
# list models
echo -e "\n\n### Listing models"
curl http://localhost:$PORT/v1/models
# create completion
echo -e "\n\n### Creating completions"
curl -X 'POST' \
"http://localhost:$PORT/v1/chat/completions" \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"messages": [
{
"role":"user",
"content":"what is deep learning?"
}
],
"max_tokens": 64,
"stream": true,
"temperature": 0.7,
"top_p": 0.9,
"frequency_penalty": 0.1,
"presence_penalty": 0.2,
"top_k": 5
}'
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uvloop
from preprocessor.common import parse_vllm_args
from triton_distributed.runtime import (
DistributedRuntime,
ModelDeploymentCard,
OAIChatPreprocessor,
triton_worker,
)
uvloop.install()
@triton_worker()
async def preprocessor(runtime: DistributedRuntime, model_name: str, model_path: str):
# create model deployment card
mdc = await ModelDeploymentCard.from_local_path(model_path, model_name)
# create preprocessor endpoint
component = runtime.namespace("triton-init").component("preprocessor")
await component.create_service()
endpoint = component.endpoint("generate")
# create backend endpoint
backend = runtime.namespace("triton-init").component("backend").endpoint("generate")
# start preprocessor service with next backend
chat = OAIChatPreprocessor(mdc, endpoint, next=backend)
await chat.start()
if __name__ == "__main__":
args = parse_vllm_args()
asyncio.run(preprocessor(args.model, args.model_path))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import inspect
import uuid
from contextlib import AsyncContextDecorator
from typing import Any
import uvloop
from preprocessor.common import NvAsyncEngineArgs, parse_vllm_args
from vllm import SamplingParams
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.outputs import CompletionOutput
from triton_distributed.runtime import (
Backend,
DistributedRuntime,
ModelDeploymentCard,
triton_endpoint,
triton_worker,
)
finish_reason_map = {
None: None,
"stop": "stop",
"abort": "cancelled",
"length": "length",
"error": "error",
}
class DeltaState:
"""
The vLLM AsyncEngine returns the full internal state of each slot per forward pass.
The OpenAI ChatCompletionResponseDelta object only requires the delta, so this object
is used to track the state of the last forward pass to calculate the delta.
"""
def __init__(self):
self.token_ids = None
self.last_token_count = 0
def delta(self, choice):
self.token_ids = choice.token_ids
tokens_produced = len(choice.token_ids) - self.last_token_count
self.last_token_count = len(choice.token_ids)
return choice.token_ids[-tokens_produced:]
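# Illustrative example (not part of this change): if the engine reports
# token_ids [1, 2, 3] on one forward pass and [1, 2, 3, 4, 5] on the next,
# delta() returns [1, 2, 3] and then [4, 5], which is exactly the per-step
# delta an OpenAI-style streaming response needs.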
class VllmEngine(AsyncContextDecorator):
"""
Request handler for the generate endpoint
"""
def __init__(self, engine_args: NvAsyncEngineArgs, mdc: ModelDeploymentCard):
self.mdc = mdc
self.engine_args = engine_args
print("vllm backend started")
async def __aenter__(self):
await self.async_init()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
print("vllm backend exited")
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args, False
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
def to_backend_output(self, response: CompletionOutput, delta_token_ids: list[int]):
return {
"token_ids": delta_token_ids,
"tokens": [],
"finish_reason": finish_reason_map.get(response.finish_reason, "stop"),
"cum_log_probs": response.cumulative_logprob,
"text": None,
}
def to_sampling_params(self, request) -> SamplingParams:
sampling_params_names = inspect.signature(SamplingParams).parameters.keys()
sampling_params = {
k: v
for k, v in request.get("sampling_options", {}).items()
if k in sampling_params_names and v is not None
}
return SamplingParams(**sampling_params)
@triton_endpoint(Any, CompletionOutput)
async def generate(self, request):
state = DeltaState()
request_id = str(uuid.uuid4())
sampling_params = self.to_sampling_params(request)
inputs = {"prompt_token_ids": request["token_ids"]}
stream = self.engine_client.generate(
inputs, sampling_params, request_id=request_id
)
async for request_output in stream:
for choice in request_output.outputs:
delta_token_ids = state.delta(choice)
yield self.to_backend_output(choice, delta_token_ids)
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: NvAsyncEngineArgs):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("triton-init").component("backend")
await component.create_service()
endpoint = component.endpoint("generate")
mdc = await ModelDeploymentCard.from_local_path(
engine_args.model_path, engine_args.model
)
async with VllmEngine(engine_args, mdc) as engine:
backend = Backend(mdc, endpoint)
await backend.start(engine.generate)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
asyncio.run(worker(engine_args))
......@@ -67,6 +67,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?;
m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
m.add_class::<llm::backend::Backend>()?;
engine::add_to_module(m)?;
......
......@@ -15,4 +15,7 @@
use super::*;
pub mod backend;
pub mod kv;
pub mod model_card;
pub mod preprocessor;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use crate::llm::model_card::ModelDeploymentCard;
use llm_rs::protocols::common::llm_backend::{BackendInput, BackendOutput};
use llm_rs::types::Annotated;
use triton_distributed_runtime::pipeline::{Operator, ServiceBackend, ServiceFrontend, Source};
use crate::engine::PythonAsyncEngine;
#[pyclass]
pub(crate) struct Backend {
inner: Arc<llm_rs::backend::Backend>,
endpoint: Endpoint,
}
#[pymethods]
impl Backend {
#[new]
fn new(mdc: ModelDeploymentCard, endpoint: Endpoint) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
let backend = runtime
.block_on(llm_rs::backend::Backend::from_mdc(mdc.inner))
.map_err(to_pyerr)?;
Ok(Self {
inner: backend,
endpoint,
})
}
fn start<'p>(&self, py: Python<'p>, generator: PyObject) -> PyResult<Bound<'p, PyAny>> {
let frontend =
ServiceFrontend::<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>>::new();
let backend = self.inner.into_operator();
let engine = Arc::new(PythonAsyncEngine::new(
generator,
self.endpoint.event_loop.clone(),
)?);
let engine = ServiceBackend::from_engine(engine);
let pipeline = frontend
.link(backend.forward_edge())
.map_err(to_pyerr)?
.link(engine)
.map_err(to_pyerr)?
.link(backend.backward_edge())
.map_err(to_pyerr)?
.link(frontend)
.map_err(to_pyerr)?;
let ingress = Ingress::for_engine(pipeline).map_err(to_pyerr)?;
let builder = self.endpoint.inner.endpoint_builder().handler(ingress);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder.start().await.map_err(to_pyerr)?;
Ok(())
})
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use llm_rs::model_card::model::ModelDeploymentCard as RsModelDeploymentCard;
#[pyclass]
#[derive(Clone)]
pub(crate) struct ModelDeploymentCard {
pub(crate) inner: RsModelDeploymentCard,
}
impl ModelDeploymentCard {}
#[pymethods]
impl ModelDeploymentCard {
#[staticmethod]
fn from_local_path(
path: String,
model_name: String,
py: Python<'_>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let card = RsModelDeploymentCard::from_local_path(&path, Some(&model_name))
.await
.map_err(to_pyerr)?;
Ok(ModelDeploymentCard { inner: card })
})
}
#[staticmethod]
fn from_json_str(json: String) -> PyResult<ModelDeploymentCard> {
let card = RsModelDeploymentCard::load_from_json_str(&json).map_err(to_pyerr)?;
Ok(ModelDeploymentCard { inner: card })
}
fn to_json_str(&self) -> PyResult<String> {
let json = self.inner.to_json().map_err(to_pyerr)?;
Ok(json)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use crate::llm::model_card::ModelDeploymentCard;
use llm_rs::{
preprocessor::OpenAIPreprocessor,
protocols::common::llm_backend::{BackendInput, BackendOutput},
types::{
openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
},
Annotated,
},
};
use triton_distributed_runtime::pipeline::{Operator, ServiceFrontend, Source};
use triton_distributed_runtime::pipeline::{ManyOut, SegmentSink, SingleIn};
#[pyclass]
pub(crate) struct OAIChatPreprocessor {
inner: Arc<llm_rs::preprocessor::OpenAIPreprocessor>,
current: Endpoint,
next: Endpoint,
}
#[pymethods]
impl OAIChatPreprocessor {
#[new]
fn new(mdc: ModelDeploymentCard, current: Endpoint, next: Endpoint) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
let preprocessor = runtime
.block_on(OpenAIPreprocessor::new(mdc.inner.clone()))
.map_err(to_pyerr)?;
Ok(Self {
inner: preprocessor,
current,
next,
})
}
fn start<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let frontend = ServiceFrontend::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let network =
SegmentSink::<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>>::new();
let preprocessor = self.inner.into_operator();
let pipeline = frontend
.link(preprocessor.forward_edge())
.map_err(to_pyerr)?
.link(network.clone())
.map_err(to_pyerr)?
.link(preprocessor.backward_edge())
.map_err(to_pyerr)?
.link(frontend)
.map_err(to_pyerr)?;
let ingress = Ingress::for_engine(pipeline).map_err(to_pyerr)?;
let builder = self.current.inner.endpoint_builder().handler(ingress);
let endpoint = Arc::new(self.next.inner.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let client = Arc::new(
endpoint
.client::<BackendInput, Annotated<BackendOutput>>()
.await
.map_err(to_pyerr)?,
);
network.attach(client).map_err(to_pyerr)?;
builder.start().await.map_err(to_pyerr)?;
Ok(())
})
}
}
......@@ -199,3 +199,37 @@ class KvMetricsPublisher:
Update the KV metrics being reported.
"""
...
class ModelDeploymentCard:
"""
A model deployment card is a collection of model information
"""
...
class OAIChatPreprocessor:
"""
A preprocessor for OpenAI chat completions
"""
...
async def start(self) -> None:
"""
Start the preprocessor
"""
...
class Backend:
"""
The LLM backend engine manages resources and concurrency for executing inference
requests in LLM engines (trtllm, vllm, sglang, etc.)
"""
...
async def start(self, handler: RequestHandler) -> None:
"""
Start the backend engine and forward requests to the downstream LLM engine
"""
...
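# A minimal usage sketch of these bindings (mirrors processor.py / worker.py in
# this change; the "triton-init" namespace and "generate" endpoint names are the
# example values used there):
#
#   mdc = await ModelDeploymentCard.from_local_path(model_path, model_name)
#   component = runtime.namespace("triton-init").component("preprocessor")
#   await component.create_service()
#   endpoint = component.endpoint("generate")
#   backend = runtime.namespace("triton-init").component("backend").endpoint("generate")
#   await OAIChatPreprocessor(mdc, endpoint, next=backend).start()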
......@@ -16,12 +16,18 @@
import asyncio
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, Union
from pydantic import BaseModel, ValidationError
# List all the classes in the _core module for re-export
# import * causes "unable to detect undefined names"
from triton_distributed._core import Backend as Backend
from triton_distributed._core import Client as Client
from triton_distributed._core import DistributedRuntime as DistributedRuntime
from triton_distributed._core import KvRouter as KvRouter
from triton_distributed._core import ModelDeploymentCard as ModelDeploymentCard
from triton_distributed._core import OAIChatPreprocessor as OAIChatPreprocessor
def triton_worker():
......@@ -54,7 +60,7 @@ def triton_worker():
def triton_endpoint(
request_model: Union[Type[BaseModel], Type[Any]], response_model: Type[BaseModel]
) -> Callable:
def decorator(
func: Callable[..., AsyncGenerator[Any, None]]
......@@ -63,8 +69,8 @@ def triton_endpoint(
async def wrapper(*args, **kwargs) -> AsyncGenerator[Any, None]:
# Validate the request
try:
args_list = list(args)
if len(args) in [1, 2] and issubclass(request_model, BaseModel):
if isinstance(args[-1], str):
args_list[-1] = request_model.parse_raw(args[-1])
elif isinstance(args[-1], dict):
......