Commit 022b6db5 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

refactor: Improve the usage of logging in worker (#29)

parent 1c1bd7da
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
from frontend.fastapi_frontend import FastApiFrontend from frontend.fastapi_frontend import FastApiFrontend
from llm.api_server.triton_distributed_engine import TritonDistributedEngine from llm.api_server.triton_distributed_engine import TritonDistributedEngine
from triton_distributed.worker.log_formatter import setup_logger from triton_distributed.worker.logger import get_logger
from .parser import parse_args from .parser import parse_args
def main(args): def main(args):
print(args) print(args)
logger = setup_logger(args.log_level, args.program_name) logger = get_logger(args.log_level, args.program_name)
logger.info("Starting") logger.info("Starting")
......
...@@ -9,8 +9,6 @@ def parse_args(): ...@@ -9,8 +9,6 @@ def parse_args():
# default_log_dir = "" example_dir.joinpath("logs") # default_log_dir = "" example_dir.joinpath("logs")
default_log_dir = "" default_log_dir = ""
parser = argparse.ArgumentParser(description="Hello World Deployment")
parser.add_argument( parser.add_argument(
"--log-dir", "--log-dir",
type=str, type=str,
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from triton_distributed.worker.deployment import Deployment as Deployment from triton_distributed.worker.deployment import Deployment as Deployment
from triton_distributed.worker.logger import get_logger as get_logger
from triton_distributed.worker.logger import get_logger_config as get_logger_config
from triton_distributed.worker.operator import Operator as Operator from triton_distributed.worker.operator import Operator as Operator
from triton_distributed.worker.operator import OperatorConfig as OperatorConfig from triton_distributed.worker.operator import OperatorConfig as OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator as RemoteOperator from triton_distributed.worker.remote_operator import RemoteOperator as RemoteOperator
......
...@@ -25,7 +25,7 @@ from triton_distributed.icp import ( ...@@ -25,7 +25,7 @@ from triton_distributed.icp import (
RequestPlane, RequestPlane,
UcpDataPlane, UcpDataPlane,
) )
from triton_distributed.worker.log_formatter import setup_logger from triton_distributed.worker.logger import get_logger
from triton_distributed.worker.worker import Worker, WorkerConfig from triton_distributed.worker.worker import Worker, WorkerConfig
LOGGER_NAME = __name__ LOGGER_NAME = __name__
...@@ -43,12 +43,13 @@ class Deployment: ...@@ -43,12 +43,13 @@ class Deployment:
data_plane: Optional[Type[DataPlane]] = UcpDataPlane, data_plane: Optional[Type[DataPlane]] = UcpDataPlane,
data_plane_args: Optional[tuple[list, dict]] = None, data_plane_args: Optional[tuple[list, dict]] = None,
log_dir="logs", log_dir="logs",
consolidate_logs=False,
starting_metrics_port=0, starting_metrics_port=0,
): ):
self._process_context = multiprocessing.get_context("spawn") self._process_context = multiprocessing.get_context("spawn")
self._worker_configs = worker_configs self._worker_configs = worker_configs
self._workers: list[multiprocessing.context.SpawnProcess] = [] self._workers: list[multiprocessing.context.SpawnProcess] = []
self._logger = setup_logger(log_level, LOGGER_NAME) self._logger = get_logger(log_level, LOGGER_NAME)
self._default_request_plane = request_plane self._default_request_plane = request_plane
self._default_request_plane_args = request_plane_args self._default_request_plane_args = request_plane_args
self._default_data_plane = data_plane self._default_data_plane = data_plane
...@@ -58,6 +59,7 @@ class Deployment: ...@@ -58,6 +59,7 @@ class Deployment:
self.request_plane_server: NatsServer = None self.request_plane_server: NatsServer = None
self._default_log_dir = log_dir self._default_log_dir = log_dir
self._default_log_level = log_level self._default_log_level = log_level
self._consolidate_logs = consolidate_logs
self._starting_metrics_port = starting_metrics_port self._starting_metrics_port = starting_metrics_port
@staticmethod @staticmethod
...@@ -103,6 +105,9 @@ class Deployment: ...@@ -103,6 +105,9 @@ class Deployment:
if not worker_config.log_level: if not worker_config.log_level:
worker_config.log_level = self._default_log_level worker_config.log_level = self._default_log_level
if self._consolidate_logs:
worker_config.consolidate_logs = True
for index in range(worker_instances): for index in range(worker_instances):
worker_config.name = f"{base_name}.{index}" worker_config.name = f"{base_name}.{index}"
worker_config.metrics_port = base_port + index worker_config.metrics_port = base_port + index
......
...@@ -14,40 +14,69 @@ ...@@ -14,40 +14,69 @@
# limitations under the License. # limitations under the License.
import logging import logging
import sys import logging.config
from typing import Any
LOGGER_NAME = "Triton Worker" _LOGGER_NAME = "Triton Distributed Worker"
_FHANDLER_CONFIG_TEMPLATE = {
"class": "logging.FileHandler",
"formatter": "standard",
}
class LogFormatter(logging.Formatter): _LOGGER_CONFIG_TEMPLATE = {"handlers": ["console"], "propagate": True}
"""Class to handle formatting of the logger outputs"""
def __init__(self, logger_name=LOGGER_NAME): _LOGGING_CONFIG_TEMPLATE = {
logger = logging.getLogger(logger_name) "version": 1,
self._log_level = logger.getEffectiveLevel() "disable_existing_loggers": False,
self._logger_name = logger_name "formatters": {
super().__init__(datefmt="%H:%M:%S") "standard": {
"format": "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
"datefmt": "%H:%M:%S",
},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "standard",
"stream": "ext://sys.stdout",
}
},
}
def get_logger_config(log_level=1, logger_name=_LOGGER_NAME, log_file=None):
config_dict: dict[str, Any] = _LOGGING_CONFIG_TEMPLATE
front = "%(asctime)s.%(msecs)03d %(filename)s:%(lineno)s"
config_dict["formatters"]["standard"][
"format"
] = f"{front} [{logger_name}] %(levelname)s: %(message)s"
if log_file:
fh_config_dict = _FHANDLER_CONFIG_TEMPLATE
fh_config_dict["filename"] = str(log_file)
config_dict["handlers"]["file"] = fh_config_dict
logger_config: dict[str, Any] = _LOGGER_CONFIG_TEMPLATE
if log_file:
logger_config["handlers"].append("file")
def format(self, record): config_dict["loggers"] = {}
front = "%(asctime)s %(filename)s:%(lineno)s" config_dict["loggers"][logger_name] = logger_config
self._style._fmt = f"{front}[{self._logger_name}] %(levelname)s: %(message)s"
return super().format(record)
return config_dict
def setup_logger(log_level=1, logger_name=LOGGER_NAME):
# TODO: Add support for taking logging level as input as well.
def get_logger(log_level=1, logger_name=_LOGGER_NAME, log_file=None):
if log_level == 0: if log_level == 0:
log_level = logging.ERROR level = logging.ERROR
elif log_level == 1: elif log_level == 1:
log_level = logging.INFO level = logging.INFO
else: else:
log_level = logging.DEBUG level = logging.DEBUG
config_dict = get_logger_config(log_level, logger_name, log_file)
logging.config.dictConfig(config_dict)
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
logger.setLevel(level=log_level) logger.setLevel(level=level)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(LogFormatter(logger_name=logger_name))
logger.addHandler(handler)
logger.propagate = True
return logger return logger
...@@ -39,8 +39,11 @@ from triton_distributed.icp.request_plane import ( ...@@ -39,8 +39,11 @@ from triton_distributed.icp.request_plane import (
set_icp_final_response, set_icp_final_response,
set_icp_response_error, set_icp_response_error,
) )
from triton_distributed.worker.logger import get_logger
from triton_distributed.worker.remote_tensor import RemoteTensor from triton_distributed.worker.remote_tensor import RemoteTensor
logger = get_logger(__name__)
class AsyncRemoteResponseIterator: class AsyncRemoteResponseIterator:
...@@ -185,7 +188,7 @@ class AsyncRemoteResponseIterator: ...@@ -185,7 +188,7 @@ class AsyncRemoteResponseIterator:
except Exception as e: except Exception as e:
message = f"Catastrophic failure in response callback: {e}" message = f"Catastrophic failure in response callback: {e}"
print(message) logger.exception(message)
# catastrophic failure # catastrophic failure
raise e from None raise e from None
......
...@@ -26,6 +26,7 @@ from tritonserver import InvalidArgumentError, Server ...@@ -26,6 +26,7 @@ from tritonserver import InvalidArgumentError, Server
from triton_distributed.icp.data_plane import DataPlane from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker.logger import get_logger
from triton_distributed.worker.operator import Operator from triton_distributed.worker.operator import Operator
from triton_distributed.worker.remote_request import RemoteInferenceRequest from triton_distributed.worker.remote_request import RemoteInferenceRequest
from triton_distributed.worker.remote_response import RemoteInferenceResponse from triton_distributed.worker.remote_response import RemoteInferenceResponse
...@@ -41,7 +42,7 @@ class TritonCoreOperator(Operator): ...@@ -41,7 +42,7 @@ class TritonCoreOperator(Operator):
data_plane: DataPlane, data_plane: DataPlane,
parameters: dict, parameters: dict,
repository: Optional[str] = None, repository: Optional[str] = None,
logger: logging.Logger = logging.getLogger(), logger: logging.Logger = get_logger(__name__),
): ):
self._repository = repository self._repository = repository
self._name = name self._name = name
...@@ -93,7 +94,7 @@ class TritonCoreOperator(Operator): ...@@ -93,7 +94,7 @@ class TritonCoreOperator(Operator):
request_id_map = {} request_id_map = {}
response_queue: asyncio.Queue = asyncio.Queue() response_queue: asyncio.Queue = asyncio.Queue()
for request in requests: for request in requests:
self._logger.info("\n\nReceived request: \n\n%s\n\n", request) self._logger.debug("\n\nReceived request: \n\n%s\n\n", request)
try: try:
local_request = request.to_local_request(self._local_model) local_request = request.to_local_request(self._local_model)
except Exception as e: except Exception as e:
...@@ -126,5 +127,5 @@ class TritonCoreOperator(Operator): ...@@ -126,5 +127,5 @@ class TritonCoreOperator(Operator):
if local_response.final: if local_response.final:
del request_id_map[local_response.request_id] del request_id_map[local_response.request_id]
self._logger.info("\n\nSending response\n\n%s\n\n", remote_response) self._logger.debug("\n\nSending response\n\n%s\n\n", remote_response)
await response_sender.send(remote_response) await response_sender.send(remote_response)
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import asyncio import asyncio
import importlib import importlib
import logging
import os import os
import pathlib import pathlib
import signal import signal
...@@ -31,7 +30,7 @@ from triton_distributed.icp.data_plane import DataPlane ...@@ -31,7 +30,7 @@ from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.nats_request_plane import NatsRequestPlane from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.request_plane import RequestPlane from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.log_formatter import LOGGER_NAME, setup_logger from triton_distributed.worker.logger import get_logger, get_logger_config
from triton_distributed.worker.operator import Operator, OperatorConfig from triton_distributed.worker.operator import Operator, OperatorConfig
from triton_distributed.worker.remote_request import ( from triton_distributed.worker.remote_request import (
RemoteInferenceRequest, RemoteInferenceRequest,
...@@ -42,7 +41,7 @@ from triton_distributed.worker.triton_core_operator import TritonCoreOperator ...@@ -42,7 +41,7 @@ from triton_distributed.worker.triton_core_operator import TritonCoreOperator
if TYPE_CHECKING: if TYPE_CHECKING:
import uvicorn import uvicorn
logger = logging.getLogger(LOGGER_NAME) logger = get_logger(__name__)
@dataclass @dataclass
...@@ -53,9 +52,9 @@ class WorkerConfig: ...@@ -53,9 +52,9 @@ class WorkerConfig:
data_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {})) data_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {}))
log_level: Optional[int] = None log_level: Optional[int] = None
operators: list[OperatorConfig] = field(default_factory=list) operators: list[OperatorConfig] = field(default_factory=list)
triton_log_path: Optional[str] = None
name: str = str(uuid.uuid1()) name: str = str(uuid.uuid1())
log_dir: Optional[str] = None log_dir: Optional[str] = None
consolidate_logs = False
metrics_port: int = 0 metrics_port: int = 0
...@@ -73,14 +72,13 @@ class Worker: ...@@ -73,14 +72,13 @@ class Worker:
self._data_plane = config.data_plane( self._data_plane = config.data_plane(
*config.data_plane_args[0], **config.data_plane_args[1] *config.data_plane_args[0], **config.data_plane_args[1]
) )
self._triton_log_path = config.triton_log_path
self._name = config.name self._name = config.name
self._log_level = config.log_level self._log_level = config.log_level
if self._log_level is None: if self._log_level is None:
self._log_level = 0 self._log_level = 0
self._operator_configs = config.operators self._operator_configs = config.operators
self._log_dir = config.log_dir self._log_dir = config.log_dir
self._consolidate_logs = config.consolidate_logs
self._stop_requested = False self._stop_requested = False
self._requests_received: Counter = Counter() self._requests_received: Counter = Counter()
self._background_tasks: dict[object, set] = {} self._background_tasks: dict[object, set] = {}
...@@ -92,6 +90,12 @@ class Worker: ...@@ -92,6 +90,12 @@ class Worker:
self._metrics_server: Optional[uvicorn.Server] = None self._metrics_server: Optional[uvicorn.Server] = None
self._component_id = self._request_plane.component_id self._component_id = self._request_plane.component_id
self._triton_core: Optional[tritonserver.Server] = None self._triton_core: Optional[tritonserver.Server] = None
self._log_file: Optional[pathlib.Path] = None
if self._log_dir:
path = pathlib.Path(self._log_dir)
path.mkdir(parents=True, exist_ok=True)
pid = os.getpid()
self._log_file = path / f"{self._name}.{self._component_id}.{pid}.log"
def _import_operators(self): def _import_operators(self):
for operator_config in self._operator_configs: for operator_config in self._operator_configs:
...@@ -134,22 +138,32 @@ class Worker: ...@@ -134,22 +138,32 @@ class Worker:
try: try:
if operator_config.log_level is None: if operator_config.log_level is None:
operator_config.log_level = self._log_level operator_config.log_level = self._log_level
operator_logger = setup_logger( operator_logger = get_logger(
log_level=operator_config.log_level, log_level=operator_config.log_level,
logger_name=f"OPERATOR{(operator_config.name,operator_config.version)}", logger_name=f"OPERATOR{(operator_config.name,operator_config.version)}",
log_file=self._log_file,
) )
if ( if (
class_ == TritonCoreOperator class_ == TritonCoreOperator
or issubclass(class_, TritonCoreOperator) or issubclass(class_, TritonCoreOperator)
) and not self._triton_core: ) and not self._triton_core:
if not self._consolidate_logs and self._log_file:
log_file = pathlib.Path(self._log_file)
stem = log_file.stem
suffix = log_file.suffix
triton_log_path = str(
log_file.parent / f"{stem}.triton{suffix}"
)
else:
triton_log_path = str(self._log_file)
self._triton_core = tritonserver.Server( self._triton_core = tritonserver.Server(
model_repository=".", model_repository=".",
log_error=True, log_error=True,
log_verbose=self._log_level, log_verbose=self._log_level,
strict_model_config=False, strict_model_config=False,
model_control_mode=tritonserver.ModelControlMode.EXPLICIT, model_control_mode=tritonserver.ModelControlMode.EXPLICIT,
log_file=self._triton_log_path, log_file=triton_log_path,
).start(wait_until_ready=True) ).start(wait_until_ready=True)
operator = class_( operator = class_(
...@@ -176,7 +190,7 @@ class Worker: ...@@ -176,7 +190,7 @@ class Worker:
self._completion_conds[operator] = asyncio.Condition() self._completion_conds[operator] = asyncio.Condition()
async def _process_request(self, request): async def _process_request(self, request):
logger.info("\n\nserver received request: \n\n%s\n\n", request) logger.debug("\n\nserver received request: \n\n%s\n\n", request)
operator_key = (request.model_name, int(request.model_version)) operator_key = (request.model_name, int(request.model_version))
...@@ -188,7 +202,7 @@ class Worker: ...@@ -188,7 +202,7 @@ class Worker:
) )
await operator.execute([remote_request]) await operator.execute([remote_request])
else: else:
logger.warn("Received request for unknown operator") logger.warning("Received request for unknown operator")
async def _process_request_task(self, operator, name, version): async def _process_request_task(self, operator, name, version):
requests = await self._request_plane.pull_requests(name, str(version)) requests = await self._request_plane.pull_requests(name, str(version))
...@@ -244,7 +258,6 @@ class Worker: ...@@ -244,7 +258,6 @@ class Worker:
await asyncio.gather(*handlers) await asyncio.gather(*handlers)
async def serve(self): async def serve(self):
error = None
try: try:
await self._request_plane.connect() await self._request_plane.connect()
except Exception as e: except Exception as e:
...@@ -260,7 +273,7 @@ class Worker: ...@@ -260,7 +273,7 @@ class Worker:
"Encountered and error when trying to connect to data plane" "Encountered and error when trying to connect to data plane"
) )
raise e raise e
error = None
try: try:
self._import_operators() self._import_operators()
logger.info("Worker started...") logger.info("Worker started...")
...@@ -317,7 +330,17 @@ class Worker: ...@@ -317,7 +330,17 @@ class Worker:
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
app = FastAPI() app = FastAPI()
config = uvicorn.Config(app, port=self._metrics_port) log_config = get_logger_config(
logger_name="uvicorn.error",
log_level=self._log_level,
log_file=self._log_file,
)
config = uvicorn.Config(
app,
port=self._metrics_port,
log_level=self._log_level,
log_config=log_config,
)
server = uvicorn.Server(config) server = uvicorn.Server(config)
@app.get("/metrics", response_class=PlainTextResponse) @app.get("/metrics", response_class=PlainTextResponse)
...@@ -329,6 +352,14 @@ class Worker: ...@@ -329,6 +352,14 @@ class Worker:
return server return server
@staticmethod
def exception_handler(loop, context):
# get details of the exception
exception = context["exception"]
message = context["message"]
# log exception
logger.error(f"Task failed, msg={message}, exception={exception}")
async def _wait_for_tasks(self, loop): async def _wait_for_tasks(self, loop):
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
try: try:
...@@ -342,26 +373,10 @@ class Worker: ...@@ -342,26 +373,10 @@ class Worker:
def start(self): def start(self):
exit_condition = None exit_condition = None
logger = get_logger(log_level=self._log_level, log_file=self._log_file)
if self._log_dir: logger.info(f"Starting Worker ==> {self._name}")
pid = os.getpid()
os.makedirs(self._log_dir, exist_ok=True)
stdout_path = os.path.join(
self._log_dir, f"{self._name}.{self._component_id}.{pid}.stdout.log"
)
stderr_path = os.path.join(
self._log_dir, f"{self._name}.{self._component_id}.{pid}.stderr.log"
)
if not self._triton_log_path:
self._triton_log_path = os.path.join(
self._log_dir, f"{self._name}.{self._component_id}.{pid}.triton.log"
)
sys.stdout = open(stdout_path, "w", buffering=1)
sys.stderr = open(stderr_path, "w", buffering=1)
triton_log = open(self._triton_log_path, "w", buffering=1)
triton_log.close()
setup_logger(log_level=self._log_level)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.set_exception_handler(Worker.exception_handler)
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
# Note: mypy has known issues inferring # Note: mypy has known issues inferring
...@@ -391,13 +406,6 @@ class Worker: ...@@ -391,13 +406,6 @@ class Worker:
else: else:
exit_condition = serve_result exit_condition = serve_result
sys.stdout.flush()
sys.stderr.flush()
if self._log_dir:
sys.stdout.close()
sys.stderr.close()
if exit_condition is not None: if exit_condition is not None:
sys.exit(1) sys.exit(1)
else: else:
......
...@@ -23,9 +23,8 @@ import pytest ...@@ -23,9 +23,8 @@ import pytest
import pytest_asyncio import pytest_asyncio
from triton_distributed.icp.nats_request_plane import NatsServer from triton_distributed.icp.nats_request_plane import NatsServer
from triton_distributed.worker.log_formatter import LOGGER_NAME
logger = logging.getLogger(LOGGER_NAME) logger = logging.getLogger(__name__)
NATS_PORT = 4223 NATS_PORT = 4223
......
...@@ -43,19 +43,20 @@ class AddMultiplyDivide(Operator): ...@@ -43,19 +43,20 @@ class AddMultiplyDivide(Operator):
self._divide_model = RemoteOperator( self._divide_model = RemoteOperator(
"divide", self._request_plane, self._data_plane "divide", self._request_plane, self._data_plane
) )
self._logger = logger
async def execute(self, requests: list[RemoteInferenceRequest]): async def execute(self, requests: list[RemoteInferenceRequest]):
print("in execute!", flush=True) self._logger.debug("in execute!")
for request in requests: for request in requests:
outputs = {} outputs = {}
print(request.inputs, flush=True) self._logger.debug(request.inputs)
array = None array = None
try: try:
array = numpy.from_dlpack(request.inputs["int64_input"]) array = numpy.from_dlpack(request.inputs["int64_input"])
except Exception as e: except Exception:
print(e) self._logger.exception("Failed to retrieve inputs")
print(array) self._logger.debug(array)
response = [ response = [
response response
async for response in await self._add_model.async_infer( async for response in await self._add_model.async_infer(
...@@ -63,7 +64,7 @@ class AddMultiplyDivide(Operator): ...@@ -63,7 +64,7 @@ class AddMultiplyDivide(Operator):
) )
][0] ][0]
print(response, flush=True) self._logger.debug(response)
for output_name, output_value in response.outputs.items(): for output_name, output_value in response.outputs.items():
outputs[f"{response.model_name}_{output_name}"] = output_value outputs[f"{response.model_name}_{output_name}"] = output_value
...@@ -88,8 +89,8 @@ class AddMultiplyDivide(Operator): ...@@ -88,8 +89,8 @@ class AddMultiplyDivide(Operator):
for result in asyncio.as_completed([multiply_respnoses, divide_responses]): for result in asyncio.as_completed([multiply_respnoses, divide_responses]):
responses = await result responses = await result
async for response in responses: async for response in responses:
print("response!", response, flush=True) self._logger.debug(f"response! {response}")
print("error!", response.error, flush=True) self._logger.debug(f"error! {response.error}")
if response.error is not None: if response.error is not None:
error = response.error error = response.error
break break
......
...@@ -39,16 +39,19 @@ class Identity(Operator): ...@@ -39,16 +39,19 @@ class Identity(Operator):
self._request_plane = request_plane self._request_plane = request_plane
self._data_plane = data_plane self._data_plane = data_plane
self._params = params self._params = params
self._logger = logger
async def execute(self, requests: list[RemoteInferenceRequest]): async def execute(self, requests: list[RemoteInferenceRequest]):
for request in requests: for request in requests:
try: try:
array = numpy.from_dlpack(request.inputs["input"]) array = numpy.from_dlpack(request.inputs["input"])
except Exception as e: except Exception as e:
print(e) self.logger.exception("Failed to retrieve inputs")
await request.response_sender().send(final=True, error=e) await request.response_sender().send(final=True, error=e)
return return
self._logger.debug("Operator received inputs")
outputs: dict[str, numpy.ndarray] = {"output": array} outputs: dict[str, numpy.ndarray] = {"output": array}
store_outputs_in_response = False store_outputs_in_response = False
......
...@@ -119,7 +119,7 @@ class MockDisaggregatedServing(Operator): ...@@ -119,7 +119,7 @@ class MockDisaggregatedServing(Operator):
return sending return sending
async def execute(self, requests: list[RemoteInferenceRequest]): async def execute(self, requests: list[RemoteInferenceRequest]):
print("in execute!", flush=True) self._logger.debug("in execute!")
error = None error = None
for request in requests: for request in requests:
""" """
......
...@@ -40,6 +40,7 @@ class TritonPythonModel: ...@@ -40,6 +40,7 @@ class TritonPythonModel:
# Using a mock hard coded auto-tokenizer # Using a mock hard coded auto-tokenizer
self.tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") self.tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
self._logger = pb_utils.Logger
def execute(self, requests): def execute(self, requests):
responses = [] responses = []
...@@ -52,7 +53,7 @@ class TritonPythonModel: ...@@ -52,7 +53,7 @@ class TritonPythonModel:
output_result = np.array( output_result = np.array(
self.tokenizer.convert_ids_to_tokens((output_ids.tolist())) self.tokenizer.convert_ids_to_tokens((output_ids.tolist()))
) )
print(f"Output Result \n\n {output_result}", flush=True) self._logger.log_verbose(f"Output Result \n\n {output_result}")
output_tensor = pb_utils.Tensor( output_tensor = pb_utils.Tensor(
"OUTPUT", output_result.astype(self.output_dtype) "OUTPUT", output_result.astype(self.output_dtype)
......
...@@ -44,9 +44,10 @@ class TritonPythonModel: ...@@ -44,9 +44,10 @@ class TritonPythonModel:
# Using a mock hard coded auto-tokenizer # Using a mock hard coded auto-tokenizer
self.tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") self.tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
self.logger = pb_utils.Logger
def execute(self, requests): def execute(self, requests):
print("In preprocessing execute!", flush=True) self.logger.log_verbose("In preprocessing execute!")
responses = [] responses = []
for idx, request in enumerate(requests): for idx, request in enumerate(requests):
...@@ -56,9 +57,9 @@ class TritonPythonModel: ...@@ -56,9 +57,9 @@ class TritonPythonModel:
request, "request_output_len" request, "request_output_len"
).as_numpy() ).as_numpy()
print(f"query(pre-proc) {query}", flush=True) self.logger.log_verbose(f"query(pre-proc) {query}")
tokenize = np.array(self.tokenizer.encode(query[0].decode())) tokenize = np.array(self.tokenizer.encode(query[0].decode()))
print(f"tokenize(pre-proc) {tokenize.size}", flush=True) self.logger.log_verbose(f"tokenize(pre-proc) {tokenize.size}")
input_length = np.array([tokenize.size]) input_length = np.array([tokenize.size])
# Just forwarding query to the pre-processed input_ids # Just forwarding query to the pre-processed input_ids
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import logging
import sys import sys
from multiprocessing import Process from multiprocessing import Process
...@@ -27,7 +26,7 @@ from cupy_backends.cuda.api.runtime import CUDARuntimeError ...@@ -27,7 +26,7 @@ from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.nats_request_plane import NatsRequestPlane from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.log_formatter import LOGGER_NAME from triton_distributed.worker.logger import get_logger
from triton_distributed.worker.operator import OperatorConfig from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.triton_core_operator import TritonCoreOperator from triton_distributed.worker.triton_core_operator import TritonCoreOperator
...@@ -38,10 +37,9 @@ MODEL_REPOSITORY = ( ...@@ -38,10 +37,9 @@ MODEL_REPOSITORY = (
"/workspace/worker/tests/python/integration/operators/triton_core_models" "/workspace/worker/tests/python/integration/operators/triton_core_models"
) )
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators" OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
TRITON_LOG_FILE = "triton.log"
TRITON_LOG_LEVEL = 6 TRITON_LOG_LEVEL = 6
logger = logging.getLogger(LOGGER_NAME) logger = get_logger(__name__)
# Run cupy's cuda.is_available once to # Run cupy's cuda.is_available once to
# avoid the exception hitting runtime code. # avoid the exception hitting runtime code.
...@@ -98,6 +96,7 @@ def workers(request, log_dir): ...@@ -98,6 +96,7 @@ def workers(request, log_dir):
worker_log_dir = test_log_dir / name worker_log_dir = test_log_dir / name
worker_configs.append( worker_configs.append(
WorkerConfig( WorkerConfig(
name=name,
request_plane=NatsRequestPlane, request_plane=NatsRequestPlane,
data_plane=UcpDataPlane, data_plane=UcpDataPlane,
request_plane_args=( request_plane_args=(
...@@ -106,7 +105,6 @@ def workers(request, log_dir): ...@@ -106,7 +105,6 @@ def workers(request, log_dir):
), ),
log_level=TRITON_LOG_LEVEL, log_level=TRITON_LOG_LEVEL,
log_dir=str(worker_log_dir), log_dir=str(worker_log_dir),
triton_log_path=str(worker_log_dir / TRITON_LOG_FILE),
operators=[operator_config], operators=[operator_config],
) )
) )
...@@ -223,7 +221,11 @@ def run(num_requests, store_inputs_in_request=False): ...@@ -223,7 +221,11 @@ def run(num_requests, store_inputs_in_request=False):
[(False, False), (True, True)], [(False, False), (True, True)],
) )
def test_add_multiply_divide( def test_add_multiply_divide(
request, nats_server, workers, store_inputs_in_request, store_outputs_in_response request,
nats_server,
workers,
store_inputs_in_request,
store_outputs_in_response,
): ):
# Using a separate process to use data plane across multiple tests. # Using a separate process to use data plane across multiple tests.
p = Process(target=run, args=(2, store_inputs_in_request)) p = Process(target=run, args=(2, store_inputs_in_request))
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import pathlib
import sys
from multiprocessing import Process
import cupy
import numpy
import pytest
import ucp
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.logger import get_logger
from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.triton_core_operator import TritonCoreOperator
from triton_distributed.worker.worker import WorkerConfig
# Port for the test-local NATS server (distinct from the NATS default 4222).
NATS_PORT = 4223
# Repository holding the Triton-core model definitions used by the
# add/multiply/divide operators.
MODEL_REPOSITORY = (
    "/workspace/worker/tests/python/integration/operators/triton_core_models"
)
# Repository holding the custom Python operators (e.g. AddMultiplyDivide).
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
# Numeric Triton log verbosity passed to every worker in this module.
TRITON_LOG_LEVEL = 6
# Module-level logger obtained through the worker logging helper.
logger = get_logger(__name__)
# Run cupy's cuda.is_available once at import time so the
# CUDARuntimeError raised on hosts without CUDA is absorbed here
# rather than surfacing later inside runtime code.
try:
    # Original used `if available: pass / else: print` — inverted to the
    # idiomatic single-branch form; behavior and messages are unchanged.
    if not cupy.cuda.is_available():
        print("CUDA not available.")
except CUDARuntimeError:
    print("CUDA not available")
# TODO
# Decide if this should be
# pre merge, nightly, or weekly
# Applies the pre_merge mark to every test collected from this module.
pytestmark = pytest.mark.pre_merge
@pytest.fixture
def workers(request, log_dir):
    """Start one worker per operator and tear the deployment down afterwards.

    Builds configs for the three Triton-core operators (add, multiply,
    divide) plus the custom ``add_multiply_divide`` operator, starts a
    ``Deployment`` with one worker per operator (each logging into its
    own per-test sub-directory), yields the deployment, and shuts it
    down once the test completes.
    """
    # Triton-core operators share everything except their name.
    configs = {
        core_name: OperatorConfig(
            name=core_name,
            implementation=TritonCoreOperator,
            version=1,
            max_inflight_requests=10,
            repository=MODEL_REPOSITORY,
        )
        for core_name in ("add", "multiply", "divide")
    }
    # Custom operator implemented in the operators repository.
    configs["add_multiply_divide"] = OperatorConfig(
        name="add_multiply_divide",
        implementation="add_multiply_divide:AddMultiplyDivide",
        version=1,
        max_inflight_requests=10,
        repository=OPERATORS_REPOSITORY,
    )

    per_test_dir = log_dir / request.node.name
    per_test_dir.mkdir(parents=True, exist_ok=True)

    # One worker per operator; each worker logs into its own directory.
    worker_configs = [
        WorkerConfig(
            name=op_name,
            request_plane=NatsRequestPlane,
            data_plane=UcpDataPlane,
            request_plane_args=(
                [],
                {"request_plane_uri": f"nats://localhost:{NATS_PORT}"},
            ),
            log_level=TRITON_LOG_LEVEL,
            log_dir=str(per_test_dir / op_name),
            operators=[op_config],
        )
        for op_name, op_config in configs.items()
    ]

    deployment = Deployment(
        worker_configs,
        consolidate_logs=request.getfixturevalue("consolidate_logs"),
        log_dir=log_dir,
    )
    deployment.start()
    yield deployment
    deployment.shutdown()
def _create_inputs(number, size):
inputs = []
outputs = []
for index in range(number):
input_ = numpy.random.randint(low=1, high=100, size=[2, size])
expected_ = {}
expected_["add_int64_output_total"] = numpy.array([[input_.sum()]])
expected_["add_int64_output_partial"] = numpy.array([[x.sum() for x in input_]])
expected_["multiply_int64_output_total"] = numpy.array(
[[x.prod() for x in expected_["add_int64_output_partial"]]]
)
divisor = expected_["add_int64_output_total"][0][0]
dividends = expected_["add_int64_output_partial"]
expected_["divide_fp64_output_partial"] = numpy.array(
[numpy.divide(dividends, divisor)]
)
inputs.append(input_)
outputs.append(expected_)
return inputs, outputs
async def post_requests(num_requests):
    """
    Post requests to add_multiply_divide operator.

    Sends ``num_requests`` inference requests over the request plane,
    awaits every response, and asserts each returned tensor equals the
    locally computed expected outputs from ``_create_inputs``.
    """
    # This coroutine is driven via run() in a separate Process; reset
    # UCX so the child does not reuse any inherited UCP state.
    ucp.reset()
    timeout = 5
    # Connect the tensor transport (UCX) and the request plane (NATS).
    data_plane = UcpDataPlane()
    data_plane.connect()
    request_plane = NatsRequestPlane(f"nats://localhost:{NATS_PORT}")
    await request_plane.connect()
    # Client-side proxy for the remote "add_multiply_divide" operator.
    add_multiply_divide_operator = RemoteOperator(
        "add_multiply_divide", request_plane, data_plane
    )
    results = []
    expected_results = {}
    inputs, outputs = _create_inputs(num_requests, 40)
    # Fire all requests without awaiting, remembering the expected
    # outputs keyed by request id so responses can be matched later.
    for i, input_ in enumerate(inputs):
        request_id = str(i)
        request = add_multiply_divide_operator.create_request(
            inputs={"int64_input": input_}, request_id=request_id
        )
        print(request)
        results.append(add_multiply_divide_operator.async_infer(request))
        expected_results[request_id] = outputs[i]
    # Consume responses in completion order and verify every output.
    for result in asyncio.as_completed(results):
        responses = await result
        async for response in responses:
            print(response)
            for output_name, expected_value in expected_results[
                response.request_id
            ].items():
                output = response.outputs[output_name]
                # Copy the tensor back to host memory before comparing.
                output_value = numpy.from_dlpack(output.to_host())
                numpy.testing.assert_equal(output_value, expected_value)
                # NOTE(review): explicit del presumably releases the
                # data-plane buffer promptly — confirm against the
                # UcpDataPlane ownership rules.
                del output
            print(expected_results[response.request_id])
            del response
    timeout = 5
    data_plane.close(timeout)
    await request_plane.close()
def run(num_requests):
    """Drive post_requests to completion and exit with its result as status."""
    status = asyncio.run(post_requests(num_requests=num_requests))
    sys.exit(status)
@pytest.mark.skipif(
    "(not os.path.exists('/usr/local/bin/nats-server'))",
    reason="NATS.io not present",
)
@pytest.mark.timeout(30)
@pytest.mark.parametrize(
    "consolidate_logs",
    [True, False],
)
def test_consolidate_logs(request, nats_server, workers, consolidate_logs, log_dir):
    """Check the per-worker log layout with and without log consolidation."""
    # Using a separate process to use data plane across multiple tests.
    client = Process(target=run, args=(2,))
    client.start()
    client.join()
    assert client.exitcode == 0

    # Each worker logged into its own sub-directory of the per-test dir.
    test_log_dir = pathlib.Path(log_dir) / request.node.name
    worker_dirs = list(test_log_dir.iterdir())
    for worker_dir in worker_dirs:
        # Without consolidation, every worker except add_multiply_divide
        # is expected to produce two log files; otherwise exactly one.
        if not consolidate_logs and worker_dir.stem != "add_multiply_divide":
            expected_log_count = 2
        else:
            expected_log_count = 1
        log_files = list((test_log_dir / worker_dir.stem).iterdir())
        assert len(log_files) == expected_log_count
    # One directory per worker: add, multiply, divide, add_multiply_divide.
    assert len(worker_dirs) == 4
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import logging
import sys import sys
import uuid import uuid
from multiprocessing import Process from multiprocessing import Process
...@@ -28,7 +27,7 @@ from cupy_backends.cuda.api.runtime import CUDARuntimeError ...@@ -28,7 +27,7 @@ from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.nats_request_plane import NatsRequestPlane from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.log_formatter import LOGGER_NAME from triton_distributed.worker.logger import get_logger
from triton_distributed.worker.operator import OperatorConfig from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.worker import WorkerConfig from triton_distributed.worker.worker import WorkerConfig
...@@ -38,10 +37,9 @@ MODEL_REPOSITORY = ( ...@@ -38,10 +37,9 @@ MODEL_REPOSITORY = (
"/workspace/worker/tests/python/integration/operators/triton_core_models" "/workspace/worker/tests/python/integration/operators/triton_core_models"
) )
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators" OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
TRITON_LOG_FILE = "triton.log"
TRITON_LOG_LEVEL = 6 TRITON_LOG_LEVEL = 6
logger = logging.getLogger(LOGGER_NAME) logger = get_logger(__name__)
# Run cupy's cuda.is_available once to # Run cupy's cuda.is_available once to
# avoid the exception hitting runtime code. # avoid the exception hitting runtime code.
...@@ -81,6 +79,7 @@ def workers(request, log_dir, number_workers=10): ...@@ -81,6 +79,7 @@ def workers(request, log_dir, number_workers=10):
worker_log_dir = test_log_dir / (operator_name + "_" + str(i)) worker_log_dir = test_log_dir / (operator_name + "_" + str(i))
worker_configs.append( worker_configs.append(
WorkerConfig( WorkerConfig(
name=operator_name,
request_plane=NatsRequestPlane, request_plane=NatsRequestPlane,
data_plane=UcpDataPlane, data_plane=UcpDataPlane,
request_plane_args=( request_plane_args=(
...@@ -89,7 +88,6 @@ def workers(request, log_dir, number_workers=10): ...@@ -89,7 +88,6 @@ def workers(request, log_dir, number_workers=10):
), ),
log_level=TRITON_LOG_LEVEL, log_level=TRITON_LOG_LEVEL,
log_dir=str(worker_log_dir), log_dir=str(worker_log_dir),
triton_log_path=str(worker_log_dir / TRITON_LOG_FILE),
operators=[operator_config], operators=[operator_config],
) )
) )
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import logging
import queue import queue
import sys import sys
import time import time
...@@ -34,7 +33,7 @@ from tritonserver import Tensor ...@@ -34,7 +33,7 @@ from tritonserver import Tensor
from triton_distributed.icp.nats_request_plane import NatsRequestPlane from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.log_formatter import LOGGER_NAME from triton_distributed.worker.logger import get_logger
from triton_distributed.worker.operator import OperatorConfig from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.triton_core_operator import TritonCoreOperator from triton_distributed.worker.triton_core_operator import TritonCoreOperator
...@@ -45,10 +44,9 @@ MODEL_REPOSITORY = ( ...@@ -45,10 +44,9 @@ MODEL_REPOSITORY = (
"/workspace/worker/tests/python/integration/operators/triton_core_models" "/workspace/worker/tests/python/integration/operators/triton_core_models"
) )
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators" OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
TRITON_LOG_FILE = "triton.log"
TRITON_LOG_LEVEL = 6 TRITON_LOG_LEVEL = 6
logger = logging.getLogger(LOGGER_NAME) logger = get_logger(__name__)
# Run cupy's cuda.is_available once to # Run cupy's cuda.is_available once to
# avoid the exception hitting runtime code. # avoid the exception hitting runtime code.
...@@ -102,6 +100,7 @@ def workers(request, log_dir): ...@@ -102,6 +100,7 @@ def workers(request, log_dir):
worker_log_dir = test_log_dir / name worker_log_dir = test_log_dir / name
worker_configs.append( worker_configs.append(
WorkerConfig( WorkerConfig(
name=name,
request_plane=NatsRequestPlane, request_plane=NatsRequestPlane,
data_plane=UcpDataPlane, data_plane=UcpDataPlane,
request_plane_args=( request_plane_args=(
...@@ -110,7 +109,6 @@ def workers(request, log_dir): ...@@ -110,7 +109,6 @@ def workers(request, log_dir):
), ),
log_level=TRITON_LOG_LEVEL, log_level=TRITON_LOG_LEVEL,
log_dir=str(worker_log_dir), log_dir=str(worker_log_dir),
triton_log_path=str(worker_log_dir / TRITON_LOG_FILE),
operators=[operator_config], operators=[operator_config],
) )
) )
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
import asyncio import asyncio
import logging
import numpy import numpy
import pytest import pytest
...@@ -35,7 +34,7 @@ import ucp ...@@ -35,7 +34,7 @@ import ucp
from triton_distributed.icp.nats_request_plane import NatsRequestPlane from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.log_formatter import LOGGER_NAME from triton_distributed.worker.logger import get_logger
from triton_distributed.worker.operator import OperatorConfig from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.worker import WorkerConfig from triton_distributed.worker.worker import WorkerConfig
...@@ -45,10 +44,9 @@ MODEL_REPOSITORY = ( ...@@ -45,10 +44,9 @@ MODEL_REPOSITORY = (
"/workspace/worker/tests/python/integration/operators/triton_core_models" "/workspace/worker/tests/python/integration/operators/triton_core_models"
) )
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators" OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
TRITON_LOG_FILE = "triton.log" TRITON_LOG_LEVEL = 0
TRITON_LOG_LEVEL = 6
logger = logging.getLogger(LOGGER_NAME) logger = get_logger(__name__)
# TODO # TODO
# Decide if this should be # Decide if this should be
...@@ -81,6 +79,7 @@ def workers(log_dir, request, number_workers=1): ...@@ -81,6 +79,7 @@ def workers(log_dir, request, number_workers=1):
worker_log_dir = test_log_dir / (operator_name + "_" + str(i)) worker_log_dir = test_log_dir / (operator_name + "_" + str(i))
worker_configs.append( worker_configs.append(
WorkerConfig( WorkerConfig(
name=operator_name,
request_plane=NatsRequestPlane, request_plane=NatsRequestPlane,
data_plane=UcpDataPlane, data_plane=UcpDataPlane,
request_plane_args=( request_plane_args=(
...@@ -89,7 +88,6 @@ def workers(log_dir, request, number_workers=1): ...@@ -89,7 +88,6 @@ def workers(log_dir, request, number_workers=1):
), ),
log_level=TRITON_LOG_LEVEL, log_level=TRITON_LOG_LEVEL,
log_dir=str(worker_log_dir), log_dir=str(worker_log_dir),
triton_log_path=str(worker_log_dir / TRITON_LOG_FILE),
operators=[operator_config], operators=[operator_config],
) )
) )
...@@ -222,7 +220,7 @@ def data_plane_tracker(): ...@@ -222,7 +220,7 @@ def data_plane_tracker():
"tensor_size_in_kb", "tensor_size_in_kb",
[10, 100, 500], [10, 100, 500],
) )
@pytest.mark.benchmark(min_rounds=50, max_time=0.5) @pytest.mark.benchmark(min_rounds=100, max_time=1)
def test_identity( def test_identity(
request, request,
nats_server, nats_server,
......
...@@ -17,9 +17,9 @@ import logging ...@@ -17,9 +17,9 @@ import logging
import pytest import pytest
from triton_distributed.worker.log_formatter import LOGGER_NAME, setup_logger from triton_distributed.worker.logger import get_logger
logger = logging.getLogger(LOGGER_NAME) logger = logging.getLogger(__name__)
MSG = "This is a sample message" MSG = "This is a sample message"
...@@ -67,6 +67,6 @@ def reset_logger(caplog): ...@@ -67,6 +67,6 @@ def reset_logger(caplog):
) )
def test_logging(reset_logger, caplog, log_level, expected_record_counts): def test_logging(reset_logger, caplog, log_level, expected_record_counts):
caplog.set_level(log_level) caplog.set_level(log_level)
setup_logger(log_level=log_level) logger = get_logger(logger_name="test_logging", log_level=log_level)
logging_function(logger) logging_function(logger)
assert len(caplog.records) == expected_record_counts assert len(caplog.records) == expected_record_counts
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment