Unverified commit 222245e2, authored by Tanmay Verma, committed by GitHub
Browse files

refactor: Move engine and publisher from dynamo.llm.tensorrtllm to dynamo.trtllm (#2128)

parent 4498a77d
...@@ -11,15 +11,12 @@ from tensorrt_llm import SamplingParams ...@@ -11,15 +11,12 @@ from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.llm import ( from dynamo.llm import ModelType, register_llm
ModelType,
get_tensorrtllm_engine,
get_tensorrtllm_publisher,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.utils.request_handlers.handlers import ( from dynamo.trtllm.engine import get_llm_engine
from dynamo.trtllm.publisher import get_publisher
from dynamo.trtllm.request_handlers.handlers import (
RequestHandlerConfig, RequestHandlerConfig,
RequestHandlerFactory, RequestHandlerFactory,
) )
...@@ -129,7 +126,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -129,7 +126,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# We already detokenize inside HandlerBase. No need to also do it in TRTLLM. # We already detokenize inside HandlerBase. No need to also do it in TRTLLM.
default_sampling_params.detokenize = False default_sampling_params.detokenize = False
async with get_tensorrtllm_engine(engine_args) as engine: async with get_llm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
if is_first_worker(config): if is_first_worker(config):
...@@ -159,7 +156,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -159,7 +156,7 @@ async def init(runtime: DistributedRuntime, config: Config):
kv_listener = runtime.namespace(config.namespace).component( kv_listener = runtime.namespace(config.namespace).component(
config.component config.component
) )
async with get_tensorrtllm_publisher( async with get_publisher(
component, component,
engine, engine,
kv_listener, kv_listener,
......
...@@ -9,7 +9,7 @@ import traceback ...@@ -9,7 +9,7 @@ import traceback
import weakref import weakref
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from queue import Queue from queue import Queue
from typing import Callable, Optional, Union from typing import Awaitable, Callable, Optional, Union
from dynamo.llm import ( from dynamo.llm import (
ForwardPassMetrics, ForwardPassMetrics,
...@@ -41,7 +41,7 @@ class ManagedThread(threading.Thread): ...@@ -41,7 +41,7 @@ class ManagedThread(threading.Thread):
def __init__( def __init__(
self, self,
task: Optional[Union[Callable[..., bool], weakref.WeakMethod]], task: Optional[Union[Callable[..., Awaitable[bool]], weakref.WeakMethod]],
error_queue: Optional[Queue] = None, error_queue: Optional[Queue] = None,
name: Optional[str] = None, name: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
...@@ -62,7 +62,9 @@ class ManagedThread(threading.Thread): ...@@ -62,7 +62,9 @@ class ManagedThread(threading.Thread):
def run(self): def run(self):
while not self._stop_event.is_set(): while not self._stop_event.is_set():
task: Optional[Union[Callable[..., bool], weakref.WeakMethod]] = self.task task: Optional[
Union[Callable[..., Awaitable[bool]], weakref.WeakMethod]
] = self.task
if isinstance(task, weakref.WeakMethod): if isinstance(task, weakref.WeakMethod):
task = task() task = task()
if task is None: if task is None:
...@@ -77,9 +79,14 @@ class ManagedThread(threading.Thread): ...@@ -77,9 +79,14 @@ class ManagedThread(threading.Thread):
if self.loop is None: if self.loop is None:
logging.error("[ManagedThread] Loop not initialized!") logging.error("[ManagedThread] Loop not initialized!")
break break
self._current_future = asyncio.run_coroutine_threadsafe(
task(**self.kwargs), self.loop # Call the task function to get the coroutine
) coro = task(**self.kwargs)
if not asyncio.iscoroutine(coro):
logging.error(f"Task {task} did not return a coroutine")
break
self._current_future = asyncio.run_coroutine_threadsafe(coro, self.loop)
_ = self._current_future.result() _ = self._current_future.result()
except (asyncio.CancelledError, concurrent.futures.CancelledError): except (asyncio.CancelledError, concurrent.futures.CancelledError):
logging.debug(f"Thread {self.name} was cancelled") logging.debug(f"Thread {self.name} was cancelled")
......
...@@ -20,9 +20,9 @@ from enum import Enum ...@@ -20,9 +20,9 @@ from enum import Enum
from tensorrt_llm import SamplingParams from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from dynamo.llm.tensorrtllm.engine import TensorRTLLMEngine
from dynamo.llm.tensorrtllm.publisher import Publisher
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.publisher import Publisher
from dynamo.trtllm.utils.disagg_utils import ( from dynamo.trtllm.utils.disagg_utils import (
DisaggregatedParams, DisaggregatedParams,
DisaggregatedParamsCodec, DisaggregatedParamsCodec,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import copy import copy
from dynamo.trtllm.utils.request_handlers.handler_base import ( from dynamo.trtllm.request_handlers.handler_base import (
DisaggregationMode, DisaggregationMode,
DisaggregationStrategy, DisaggregationStrategy,
HandlerBase, HandlerBase,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import argparse import argparse
from typing import Optional from typing import Optional
from dynamo.trtllm.utils.request_handlers.handler_base import ( from dynamo.trtllm.request_handlers.handler_base import (
DisaggregationMode, DisaggregationMode,
DisaggregationStrategy, DisaggregationStrategy,
) )
......
...@@ -41,17 +41,3 @@ from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for ...@@ -41,17 +41,3 @@ from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for
from dynamo._core import make_engine from dynamo._core import make_engine
from dynamo._core import register_llm as register_llm from dynamo._core import register_llm as register_llm
from dynamo._core import run_input from dynamo._core import run_input
try:
from dynamo.llm.tensorrtllm import ( # noqa: F401
get_llm_engine as get_tensorrtllm_engine,
)
from dynamo.llm.tensorrtllm import ( # noqa: F401
get_publisher as get_tensorrtllm_publisher,
)
except ImportError:
pass # TensorRTLLM is not enabled by default
except Exception as e:
# Don't let TensorRTLLM break other engines
logger = logging.getLogger(__name__)
logger.exception(f"Error importing TensorRT-LLM components: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .engine import get_llm_engine # noqa: F401
from .publisher import get_publisher # noqa: F401
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment