# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import threading
import traceback
import weakref
from queue import Queue
from typing import Any, Callable, Coroutine, Optional, Union

from tensorrt_llm.logger import logger

# NOTE(review): this changes the level of the shared TensorRT-LLM logger as an
# import-time side effect; confirm that is intended for every importer.
logger.set_level("info")


class ManagedThread(threading.Thread):
    """Daemon thread that repeatedly runs an async task on a given event loop.

    Each iteration resolves ``task`` (dereferencing it first if it is a
    ``weakref.WeakMethod``), calls it with ``**kwargs`` to obtain a coroutine,
    schedules that coroutine on ``loop`` with
    ``asyncio.run_coroutine_threadsafe`` and blocks until it completes. The
    coroutine's result is discarded. Any ``Exception`` raised while creating,
    scheduling or running the coroutine is logged and, when ``error_queue`` is
    given, put on that queue; the thread then starts the next iteration.

    The thread exits when any of these holds at the start of an iteration:
      * ``stop()`` has been called (an in-flight coroutine is not interrupted,
        so stopping takes effect only after it finishes),
      * ``task`` is ``None`` or its ``WeakMethod`` target has been collected,
      * ``loop`` is ``None`` or has been closed.

    NOTE(review): if the loop is stopped but not closed while a coroutine is
    pending, ``future.result()`` can block indefinitely.

    Args:
        task: Callable returning a coroutine, or a ``weakref.WeakMethod`` to
            such a bound method (so the thread does not keep its owner alive).
        error_queue: Optional queue receiving exceptions raised by the task.
        name: Thread name, also used in log messages.
        loop: Event loop the coroutine is scheduled on. May be supplied later
            via ``set_loop``, but it must be set before ``start()``.
        **kwargs: Keyword arguments passed to ``task`` on every call.
    """

    def __init__(
        self,
        task: Optional[
            Union[Callable[..., Coroutine[Any, Any, Any]], weakref.WeakMethod]
        ],
        error_queue: Optional[Queue] = None,
        name: Optional[str] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        super().__init__(name=name)
        self.task = task
        self.error_queue = error_queue
        self.kwargs = kwargs
        self.loop = loop
        # Daemon so a still-running worker never blocks interpreter shutdown.
        self.daemon = True
        self.stop_event = threading.Event()

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the event loop the task's coroutines are scheduled on.

        Call this before ``start()``; ``run`` exits on its first iteration
        if no loop has been set.
        """
        self.loop = loop

    def run(self):
        """Run the task repeatedly until one of the exit conditions holds."""
        while not self.stop_event.is_set():
            task = self.task
            if isinstance(task, weakref.WeakMethod):
                task = task()
                if task is None:
                    # Normally, this should not happen.
                    logger.warning("WeakMethod is expired.")
                    break
            if task is None:
                break

            loop = self.loop
            if loop is None:
                logger.error("[ManagedThread] Loop not initialized!")
                break
            if loop.is_closed():
                # Every further iteration would fail with "Event loop is
                # closed", so keep going would spin forever, flooding the
                # log and the error queue.
                logger.error(f"[ManagedThread] Loop is closed! Stopping {self.name}.")
                break

            coro = None
            try:
                coro = task(**self.kwargs)
                future = asyncio.run_coroutine_threadsafe(coro, loop)
                # The loop now owns the coroutine; do not close it below.
                coro = None
                future.result()
            except Exception as e:
                if asyncio.iscoroutine(coro):
                    # Scheduling failed: close the never-awaited coroutine to
                    # avoid a "coroutine ... was never awaited" warning.
                    coro.close()
                logger.error(
                    f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
                )
                if self.error_queue is not None:
                    self.error_queue.put(e)

        logger.info(f"Thread {self.name} stopped.")

    def stop(self):
        """Request the thread to exit before its next iteration.

        This does not cancel a coroutine that is currently running.
        """
        self.stop_event.set()