Commit e0bb5bd3 authored by Tanmay Verma, committed by GitHub

feat: LLMAPI PoC with dynamo-run launcher (#114)

parent 76b79149
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMPORTANT:
- This is only supposed to be used by dynamo-run launcher.
- It is part of bring-your-own-engine python feature in dynamo-run.
"""
import sys
from pathlib import Path

from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (
    ChatCompletionRequest,
    ChatCompletionStreamResponse,
)

from dynamo.runtime import dynamo_endpoint

# Add the project root to the Python path
project_root = str(Path(__file__).parents[1])  # Go up to the trtllm directory
if project_root not in sys.path:
    sys.path.append(project_root)

from common.base_engine import (  # noqa: E402
    BaseTensorrtLLMEngine,
    TensorrtLLMEngineConfig,
)
from common.generators import chat_generator  # noqa: E402
from common.parser import parse_dynamo_run_args  # noqa: E402

logger.set_level("info")
class DynamoTRTLLMEngine(BaseTensorrtLLMEngine):
    """
    Request handler for the generate endpoint
    """

    def __init__(self, trt_llm_engine_config: TensorrtLLMEngineConfig):
        super().__init__(trt_llm_engine_config)


# Global variable to store the engine instance; initialized in the __main__ block.
engine = None


def init_global_engine(args, engine_config):
    global engine
    logger.debug(f"Received args: {args}")
    logger.info(f"Initializing global engine with engine config: {engine_config}")
    trt_llm_engine_config = TensorrtLLMEngineConfig(
        engine_config=engine_config,
    )
    engine = DynamoTRTLLMEngine(trt_llm_engine_config)


@dynamo_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(request):
    async for response in chat_generator(engine, request):
        yield response


if __name__ == "__main__":
    args, engine_config = parse_dynamo_run_args()
    init_global_engine(args, engine_config)
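For context, the launcher file above follows a module-global engine pattern: the `__main__` block initializes the engine once, and the `generate` endpoint reads the global afterwards. A minimal, self-contained sketch of that pattern (all names below are illustrative stand-ins, not the dynamo or TensorRT-LLM APIs):

```python
import asyncio

engine = None  # initialized once at startup, read by the endpoint


def init_global_engine(config):
    global engine
    engine = {"config": config}  # stand-in for DynamoTRTLLMEngine


async def generate(request):
    # The endpoint assumes the global was initialized first.
    if engine is None:
        raise RuntimeError("Engine not initialized")
    yield {"config": engine["config"], "echo": request}


async def demo():
    init_global_engine("llm_api_config.yaml")
    async for response in generate("hello"):
        print(response)


asyncio.run(demo())
```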
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import sys
from pathlib import Path

import uvloop

# Add the project root to the Python path
project_root = str(Path(__file__).parents[1])  # Go up to the trtllm directory
if project_root not in sys.path:
    sys.path.append(project_root)

from common.parser import parse_tensorrt_llm_args  # noqa: E402

from .worker import trtllm_worker  # noqa: E402

if __name__ == "__main__":
    uvloop.install()
    args, engine_config = parse_tensorrt_llm_args()
    asyncio.run(trtllm_worker(engine_config))
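Note that `trtllm_worker` is declared with a `runtime` parameter (see the diff below) yet is called here with only `engine_config`. A plausible explanation is that the `@dynamo_worker()` decorator injects the runtime; the sketch below shows how such a decorator could work — an assumption about the mechanism, not dynamo's actual implementation:

```python
import asyncio
import functools


class DistributedRuntime:
    """Illustrative stand-in for dynamo's runtime object."""


def dynamo_worker():
    def wrap(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            runtime = DistributedRuntime()  # framework-created in reality
            return await func(runtime, *args, **kwargs)
        return wrapper
    return wrap


@dynamo_worker()
async def trtllm_worker(runtime, engine_config):
    print(type(runtime).__name__, engine_config)


asyncio.run(trtllm_worker({"model": "demo"}))
```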
@@ -13,17 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import asyncio
-import json
-import signal
-import uuid
-import uvloop
 from common.base_engine import BaseTensorrtLLMEngine, TensorrtLLMEngineConfig
-from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
-from common.processor import merge_promises, parse_chat_message_content
-from tensorrt_llm.executor import CppExecutorError
+from common.generators import chat_generator, completion_generator
+from common.parser import LLMAPIConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.serve.openai_protocol import (
     ChatCompletionRequest,
@@ -47,101 +41,17 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
     @dynamo_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
     async def generate_chat(self, request):
-        if self._llm_engine is None:
-            raise RuntimeError("Engine not initialized")
-        logger.debug(f"Received chat request: {request}")
-        request_id = str(uuid.uuid4())
-        self._ongoing_request_count += 1
-        try:
-            conversation = []
-            for message in request.messages:
-                conversation.extend(parse_chat_message_content(message))
-            tool_dicts = (
-                None
-                if request.tools is None
-                else [tool.model_dump() for tool in request.tools]
-            )
-            prompt: str = self._tokenizer.apply_chat_template(
-                conversation=conversation,
-                tokenize=False,
-                add_generation_prompt=request.add_generation_prompt,
-                tools=tool_dicts,
-                documents=request.documents,
-                chat_template=request.chat_template,
-                **(request.chat_template_kwargs or {}),
-            )
-            sampling_params = request.to_sampling_params()
-            promise = self._llm_engine.generate_async(
-                prompt,
-                sampling_params,
-                streaming=request.stream,
-            )
-            # NOTE: somehow stream and non-stream is working with the same path
-            response_generator = self.chat_processor.stream_response(
-                request, request_id, conversation, promise
-            )
-            async for response in response_generator:
-                yield response
-            self._ongoing_request_count -= 1
-        except CppExecutorError:
-            # If internal executor error is raised, shutdown the server
-            signal.raise_signal(signal.SIGINT)
-        except Exception as e:
-            raise RuntimeError("Failed to generate: " + str(e))
+        async for response in chat_generator(self, request):
+            yield response

     @dynamo_endpoint(CompletionRequest, CompletionStreamResponse)
     async def generate_completion(self, request):
-        if self._llm_engine is None:
-            raise RuntimeError("Engine not initialized")
-        self._ongoing_request_count += 1
-        logger.debug(f"Received completion request: {request}")
-        if isinstance(request.prompt, str) or (
-            isinstance(request.prompt, list) and isinstance(request.prompt[0], int)
-        ):
-            prompts = [request.prompt]
-        else:
-            prompts = request.prompt
-        promises = []
-        sampling_params = request.to_sampling_params()
-        try:
-            for prompt in prompts:
-                promise = self._llm_engine.generate_async(
-                    prompt,
-                    sampling_params,
-                    streaming=request.stream,
-                )
-                promises.append(promise)
-            generator = merge_promises(promises)
-            num_choices = (
-                len(prompts) if request.n is None else len(prompts) * request.n
-            )
-            # NOTE: always send `stream: true` to the worker, and decide whether to aggregate or not before sending the response back to client.
-            response_generator = self.completions_processor.create_completion_generator(
-                request, generator, num_choices
-            )
-            async for response in response_generator:
-                yield json.loads(response)
-            self._ongoing_request_count -= 1
-        except CppExecutorError:
-            # If internal executor error is raised, shutdown the server
-            signal.raise_signal(signal.SIGINT)
-        except Exception as e:
-            raise RuntimeError("Failed to generate: " + str(e))
+        async for response in completion_generator(self, request):
+            yield response
 @dynamo_worker()
-async def worker(runtime: DistributedRuntime, engine_config: LLMAPIConfig):
+async def trtllm_worker(runtime: DistributedRuntime, engine_config: LLMAPIConfig):
     """
     Instantiate a `backend` component and serve the `generate` endpoint
     A `Component` can serve multiple endpoints
@@ -166,9 +76,3 @@ async def worker(runtime: DistributedRuntime, engine_config: LLMAPIConfig):
         completions_endpoint.serve_endpoint(engine.generate_completion),
         chat_completions_endpoint.serve_endpoint(engine.generate_chat),
     )
-
-
-if __name__ == "__main__":
-    uvloop.install()
-    args, engine_config = parse_tensorrt_llm_args()
-    asyncio.run(worker(engine_config))
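The net effect of this diff is that both endpoint handlers collapse into thin wrappers around shared async-generator helpers in `common/generators.py` (not shown in this commit view). A minimal sketch of that pattern, with illustrative names and payloads:

```python
import asyncio
import uuid


async def chat_generator(engine, request):
    # Shared helper: validation, bookkeeping, and streaming live here
    # instead of being duplicated in every endpoint handler.
    if engine is None:
        raise RuntimeError("Engine not initialized")
    request_id = str(uuid.uuid4())
    for token in ("streamed", "response", "chunks"):
        yield {"id": request_id, "delta": token}


async def generate_chat(engine, request):
    # The endpoint handler is now just a delegating wrapper.
    async for response in chat_generator(engine, request):
        yield response


async def demo():
    async for chunk in generate_chat(object(), {"messages": []}):
        print(chunk)


asyncio.run(demo())
```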
@@ -41,15 +41,32 @@ use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine
 /// Python snippet to import a file as a module
 const PY_IMPORT: &CStr = cr#"
-import importlib.util
+import runpy
 import sys
 import os
+import functools
+import types

-spec = importlib.util.spec_from_file_location("__main__", file_path)
-module = importlib.util.module_from_spec(spec)
 module_dir = os.path.dirname(file_path)
 if module_dir not in sys.path:
     sys.path.insert(0, module_dir)
 sys.argv = sys_argv
-sys.modules["__main__"] = module
-spec.loader.exec_module(module)
+
+module_dict = runpy.run_path(file_path, run_name='__main__')
+
+# Create a module class with the generate function
+class Module:
+    def __init__(self, module_dict):
+        self.__dict__.update(module_dict)
+        self._generate_func = module_dict['generate']
+
+    async def generate(self, request):
+        async for response in self._generate_func(request):
+            yield response
+
+# Create module instance and store it in globals
+module = Module(module_dict)
+globals()['module'] = module
 "#;

 /// An engine that takes and returns strings, feeding them to a python written engine
......
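The switch from `importlib` to `runpy.run_path(file_path, run_name='__main__')` matters because running a file under the name `__main__` executes its `if __name__ == "__main__":` block — which is exactly where the launcher file above calls `init_global_engine`. A standalone sketch of the same idea (the file name is hypothetical):

```python
import runpy

# Running a file with run_name="__main__" triggers its main block,
# so module-level setup (e.g. engine initialization) happens here.
globals_dict = runpy.run_path("trtllm_engine.py", run_name="__main__")

# The file's top-level names come back as a plain dict; grab the
# async generator endpoint that the launcher expects to stream from.
generate = globals_dict["generate"]
```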