Unverified commit 222245e2, authored by Tanmay Verma and committed by GitHub
Browse files

refactor: Move engine and publisher from dynamo.llm.tensorrt_llm to dynamo.trtllm (#2128)

parent 4498a77d
......@@ -11,15 +11,12 @@ from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.llm import (
ModelType,
get_tensorrtllm_engine,
get_tensorrtllm_publisher,
register_llm,
)
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.utils.request_handlers.handlers import (
from dynamo.trtllm.engine import get_llm_engine
from dynamo.trtllm.publisher import get_publisher
from dynamo.trtllm.request_handlers.handlers import (
RequestHandlerConfig,
RequestHandlerFactory,
)
......@@ -129,7 +126,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# We already detokenize inside HandlerBase. No need to also do it in TRTLLM.
default_sampling_params.detokenize = False
async with get_tensorrtllm_engine(engine_args) as engine:
async with get_llm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
if is_first_worker(config):
......@@ -159,7 +156,7 @@ async def init(runtime: DistributedRuntime, config: Config):
kv_listener = runtime.namespace(config.namespace).component(
config.component
)
async with get_tensorrtllm_publisher(
async with get_publisher(
component,
engine,
kv_listener,
......
......@@ -9,7 +9,7 @@ import traceback
import weakref
from contextlib import asynccontextmanager
from queue import Queue
from typing import Callable, Optional, Union
from typing import Awaitable, Callable, Optional, Union
from dynamo.llm import (
ForwardPassMetrics,
......@@ -41,7 +41,7 @@ class ManagedThread(threading.Thread):
def __init__(
self,
task: Optional[Union[Callable[..., bool], weakref.WeakMethod]],
task: Optional[Union[Callable[..., Awaitable[bool]], weakref.WeakMethod]],
error_queue: Optional[Queue] = None,
name: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
......@@ -62,7 +62,9 @@ class ManagedThread(threading.Thread):
def run(self):
while not self._stop_event.is_set():
task: Optional[Union[Callable[..., bool], weakref.WeakMethod]] = self.task
task: Optional[
Union[Callable[..., Awaitable[bool]], weakref.WeakMethod]
] = self.task
if isinstance(task, weakref.WeakMethod):
task = task()
if task is None:
......@@ -77,9 +79,14 @@ class ManagedThread(threading.Thread):
if self.loop is None:
logging.error("[ManagedThread] Loop not initialized!")
break
self._current_future = asyncio.run_coroutine_threadsafe(
task(**self.kwargs), self.loop
)
# Call the task function to get the coroutine
coro = task(**self.kwargs)
if not asyncio.iscoroutine(coro):
logging.error(f"Task {task} did not return a coroutine")
break
self._current_future = asyncio.run_coroutine_threadsafe(coro, self.loop)
_ = self._current_future.result()
except (asyncio.CancelledError, concurrent.futures.CancelledError):
logging.debug(f"Thread {self.name} was cancelled")
......
......@@ -20,9 +20,9 @@ from enum import Enum
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from dynamo.llm.tensorrtllm.engine import TensorRTLLMEngine
from dynamo.llm.tensorrtllm.publisher import Publisher
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.publisher import Publisher
from dynamo.trtllm.utils.disagg_utils import (
DisaggregatedParams,
DisaggregatedParamsCodec,
......
......@@ -3,7 +3,7 @@
import copy
from dynamo.trtllm.utils.request_handlers.handler_base import (
from dynamo.trtllm.request_handlers.handler_base import (
DisaggregationMode,
DisaggregationStrategy,
HandlerBase,
......
......@@ -4,7 +4,7 @@
import argparse
from typing import Optional
from dynamo.trtllm.utils.request_handlers.handler_base import (
from dynamo.trtllm.request_handlers.handler_base import (
DisaggregationMode,
DisaggregationStrategy,
)
......
......@@ -41,17 +41,3 @@ from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for
from dynamo._core import make_engine
from dynamo._core import register_llm as register_llm
from dynamo._core import run_input
try:
from dynamo.llm.tensorrtllm import ( # noqa: F401
get_llm_engine as get_tensorrtllm_engine,
)
from dynamo.llm.tensorrtllm import ( # noqa: F401
get_publisher as get_tensorrtllm_publisher,
)
except ImportError:
pass # TensorRTLLM is not enabled by default
except Exception as e:
# Don't let TensorRTLLM break other engines
logger = logging.getLogger(__name__)
logger.exception(f"Error importing TensorRT-LLM components: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .engine import get_llm_engine # noqa: F401
from .publisher import get_publisher # noqa: F401
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment