Commit 0bfd9a76 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub

refactor: remove python native runtime

parent 8f741f14
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Sequence
import cupy
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp._dlpack import DeviceOrMemoryType, DLDeviceType
from triton_distributed.icp.data_plane import (
DataPlane,
get_icp_data_type,
get_icp_memory_type,
get_icp_shape,
get_icp_tensor_size,
)
from triton_distributed.icp.data_type import DataType
from triton_distributed.icp.memory_type import MemoryType
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from triton_distributed.icp.tensor import Tensor
# Call cupy.cuda.is_available() once at import time so that the
# CUDARuntimeError raised on hosts without CUDA does not surface later
# in runtime code.
try:
cupy.cuda.is_available()
except CUDARuntimeError:
pass
@dataclass
class RemoteTensor:
remote_tensor: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor
data_plane: DataPlane
_local_tensor: Optional[Tensor] = None
# FIXME: This is a hack to avoid double deletion of the tensor
# Tensor must be explicitly released by the user before data plane connection is closed
deleted: bool = False
@property
def data_type(self) -> DataType | None:
return get_icp_data_type(self.remote_tensor)
@property
def shape(self) -> Sequence[int] | None:
return get_icp_shape(self.remote_tensor)
@property
def memory_type(self) -> MemoryType | None:
return get_icp_memory_type(self.remote_tensor)
@property
def size(self) -> int | None:
return get_icp_tensor_size(self.remote_tensor)
@property
def local_tensor(self) -> Tensor:
        if self._local_tensor is None:
self._local_tensor = self.data_plane.get_tensor(self.remote_tensor)
if self._local_tensor is None:
raise ValueError("Not able to resolve Tensor locally")
return self._local_tensor
@property
def data_ptr(self) -> int:
return self.local_tensor.data_ptr
def __dlpack__(self, *, stream=None):
return self.local_tensor.__dlpack__(stream=stream)
def __dlpack_device__(self) -> tuple[DLDeviceType, int]:
return self.local_tensor.__dlpack_device__()
def to_string_array(self):
return self.local_tensor.to_string_array()
def to_bytes_array(self):
return self.local_tensor.to_bytes_array()
def to_host(self) -> Tensor:
return self.local_tensor.to_host()
def to_device(self, device: DeviceOrMemoryType) -> Tensor:
return self.local_tensor.to_device(device)
def __del__(self):
# FIXME: This is a hack to avoid double deletion of the tensor
# Tensor must be explicitly released by the user before data plane connection is closed
if not self.deleted:
self.data_plane.release_tensor(self.remote_tensor)
self.deleted = True
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
import uuid
from typing import Optional
try:
import tritonserver
from tritonserver import DataType as TritonDataType
from tritonserver import InvalidArgumentError
from tritonserver import MemoryBuffer as TritonMemoryBuffer
from tritonserver import MemoryType as TritonMemoryType
from tritonserver import Server as TritonCore
from tritonserver import Tensor as TritonTensor
from tritonserver._api._response import InferenceResponse
except ImportError as e:
raise ImportError("Triton Core is not installed") from e
from google.protobuf import json_format, text_format
from tritonclient.grpc import model_config_pb2
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.icp.tensor import Tensor
from triton_distributed.runtime.logger import get_logger
from triton_distributed.runtime.operator import Operator
from triton_distributed.runtime.remote_request import RemoteInferenceRequest
from triton_distributed.runtime.remote_response import RemoteInferenceResponse
class TritonCoreOperator(Operator):
def __init__(
self,
name: str,
version: int,
request_plane: RequestPlane,
data_plane: DataPlane,
parameters: dict,
repository: Optional[str] = None,
logger: logging.Logger = get_logger(__name__),
triton_core: Optional[TritonCore] = None,
):
self._repository = repository
self._name = name
self._parameters = parameters
self._triton_core = triton_core
self._version = version
self._logger = logger
self._request_plane = request_plane
self._data_plane = data_plane
self._store_outputs_in_response = self._parameters.get(
"store_outputs_in_response", False
)
if self._triton_core is None:
raise ValueError("Triton Core required for TritonCoreOperator")
if not self._repository:
self._repository = "."
if repository:
self._triton_core.register_model_repository(repository)
parameter_config = self._parameters.get("config", {})
if "parameters" not in parameter_config:
parameter_config["parameters"] = {}
parameter_config["parameters"]["component_id"] = {
"string_value": f"{self._request_plane.component_id}"
}
model_config = None
try:
model_config_path = os.path.join(
self._repository, self._name, "config.pbtxt"
)
with open(model_config_path, "r") as config_file:
model_config = text_format.Parse(
config_file.read(), model_config_pb2.ModelConfig()
)
        except Exception:
            self._logger.debug(
                "No config.pbtxt found for %s; relying on provided parameters",
                self._name,
            )
parameter_config = json_format.Parse(
json.dumps(parameter_config), model_config_pb2.ModelConfig()
)
if model_config:
model_config.MergeFrom(parameter_config)
else:
model_config = parameter_config
model_config = {"config": json_format.MessageToJson(model_config)}
self._triton_core_model = self._triton_core.load(self._name, model_config)
@staticmethod
def _triton_tensor(tensor: Tensor) -> TritonTensor:
return TritonTensor(
TritonDataType(tensor.data_type),
tensor.shape,
TritonMemoryBuffer(
tensor.memory_buffer.data_ptr,
TritonMemoryType(tensor.memory_buffer.memory_type),
tensor.memory_buffer.memory_type_id,
tensor.memory_buffer.size,
tensor.memory_buffer.owner,
),
)
@staticmethod
def _triton_core_request(
request: RemoteInferenceRequest, model: tritonserver.Model
) -> tritonserver.InferenceRequest:
triton_core_request = model.create_request()
if request.request_id is not None:
triton_core_request.request_id = request.request_id
if request.priority is not None:
triton_core_request.priority = request.priority
if request.timeout is not None:
triton_core_request.timeout = request.timeout
if request.correlation_id is not None:
triton_core_request.correlation_id = request.correlation_id
TritonCoreOperator._set_inputs(request, triton_core_request)
TritonCoreOperator._set_parameters(request, triton_core_request)
return triton_core_request
@staticmethod
def _set_inputs(
request: RemoteInferenceRequest, local_request: tritonserver.InferenceRequest
):
for input_name, remote_tensor in request.inputs.items():
local_request.inputs[input_name] = TritonCoreOperator._triton_tensor(
remote_tensor.local_tensor
)
@staticmethod
def _set_parameters(
request: RemoteInferenceRequest, local_request: tritonserver.InferenceRequest
):
for parameter_name, parameter_value in request.parameters.items():
local_request.parameters[parameter_name] = parameter_value
@staticmethod
def _remote_response(
triton_core_response: InferenceResponse, store_outputs_in_response: bool = False
) -> RemoteInferenceResponse:
result = RemoteInferenceResponse(
triton_core_response.model.name,
triton_core_response.model.version,
None,
triton_core_response.request_id,
final=triton_core_response.final,
)
for tensor_name, tensor_value in triton_core_response.outputs.items():
result.outputs[tensor_name] = tensor_value
if store_outputs_in_response:
result.store_outputs_in_response.add(tensor_name)
for parameter_name, parameter_value in triton_core_response.parameters.items():
result.parameters[parameter_name] = parameter_value
result.error = triton_core_response.error
return result
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
request_id_map = {}
response_queue: asyncio.Queue = asyncio.Queue()
for request in requests:
self._logger.debug("\n\nReceived request: \n\n%s\n\n", request)
try:
triton_core_request = TritonCoreOperator._triton_core_request(
request, self._triton_core_model
)
except Exception as e:
message = f"Can't resolve tensors for request, ignoring request,{e}"
self._logger.error(message)
await request.response_sender().send(
error=InvalidArgumentError(message), final=True
)
continue
request_id = str(uuid.uuid1())
original_id = None
if triton_core_request.request_id is not None:
original_id = triton_core_request.request_id
triton_core_request.request_id = request_id
request_id_map[request_id] = (request.response_sender(), original_id)
triton_core_request.response_queue = response_queue
self._triton_core_model.async_infer(triton_core_request)
while request_id_map:
triton_core_response = await response_queue.get()
remote_response = TritonCoreOperator._remote_response(
triton_core_response, self._store_outputs_in_response
)
response_sender, original_id = request_id_map[
triton_core_response.request_id
]
remote_response.request_id = original_id
if triton_core_response.final:
del request_id_map[triton_core_response.request_id]
self._logger.debug("\n\nSending response\n\n%s\n\n", remote_response)
await response_sender.send(remote_response)
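
A minimal configuration sketch (an assumption, not part of the commit): the "config" parameter is merged into the model's config.pbtxt, and "store_outputs_in_response" controls whether outputs travel inline with the response. The model name, repository path, and parameter values are hypothetical.

def make_triton_core_operator(request_plane, data_plane, triton_core):
    return TritonCoreOperator(
        name="identity",  # hypothetical model name
        version=1,
        request_plane=request_plane,
        data_plane=data_plane,
        parameters={
            "store_outputs_in_response": True,
            "config": {"parameters": {"extra": {"string_value": "demo"}}},
        },
        repository="./models",  # hypothetical repository path
        triton_core=triton_core,
    )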
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import importlib
import os
import pathlib
import signal
import sys
import uuid
from collections import Counter
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, Type
try:
from tritonserver import ModelControlMode as ModelControlMode
from tritonserver import Server as TritonCore
from triton_distributed.runtime.triton_core_operator import TritonCoreOperator
TRITON_CORE_AVAILABLE = True
except ImportError:
TRITON_CORE_AVAILABLE = False
TritonCoreOperator = type(None)
TritonCore = type(None) # type: ignore[misc,assignment]
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.runtime.logger import get_logger, get_logger_config
from triton_distributed.runtime.operator import Operator, OperatorConfig
from triton_distributed.runtime.remote_request import (
RemoteInferenceRequest,
RemoteResponseSender,
)
if TYPE_CHECKING:
import uvicorn
logger = get_logger(__name__)
@dataclass
class WorkerConfig:
request_plane: Type[RequestPlane] = NatsRequestPlane
data_plane: Type[DataPlane] = UcpDataPlane
request_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {}))
data_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {}))
log_level: Optional[int] = None
operators: list[OperatorConfig] = field(default_factory=list)
    # default_factory so each worker gets a fresh default name
    name: str = field(default_factory=lambda: str(uuid.uuid1()))
log_dir: Optional[str] = None
    consolidate_logs: bool = False
metrics_port: int = 0
class Worker:
def __init__(
self, config: Optional[WorkerConfig] = None, **kwargs #: Unpack[WorkerConfig]
) -> None:
if config is None:
config = WorkerConfig(**kwargs)
self._request_plane = config.request_plane(
*config.request_plane_args[0], **config.request_plane_args[1]
)
self._data_plane = config.data_plane(
*config.data_plane_args[0], **config.data_plane_args[1]
)
self._name = config.name
self._log_level = config.log_level
if self._log_level is None:
self._log_level = 0
self._operator_configs = config.operators
self._log_dir = config.log_dir
self._consolidate_logs = config.consolidate_logs
self._stop_requested = False
self._requests_received: Counter = Counter()
self._background_tasks: dict[object, set] = {}
self._completion_conds: dict[object, asyncio.Condition] = {}
self._inflight_requests: dict[object, int] = {}
        self._max_inflight_requests: dict[object, int] = {}
self._operators: dict[tuple[str, int], Operator] = {}
self._metrics_port = config.metrics_port
self._metrics_server: Optional[uvicorn.Server] = None
self._component_id = self._request_plane.component_id
self._triton_core: Optional[TritonCore] = None
self._log_file: Optional[pathlib.Path] = None
if self._log_dir:
path = pathlib.Path(self._log_dir)
path.mkdir(parents=True, exist_ok=True)
pid = os.getpid()
self._log_file = path / f"{self._name}.{self._component_id}.{pid}.log"
def _import_operators(self):
for operator_config in self._operator_configs:
if operator_config.repository:
repository_path = pathlib.Path(operator_config.repository)
sys.path.append(str(repository_path.absolute()))
else:
repository_path = pathlib.Path(".")
if isinstance(operator_config.implementation, str):
split_workflow = operator_config.implementation.split(":")
module_name = ":".join(split_workflow[:-1])
class_name = split_workflow[-1]
module_path = pathlib.Path(module_name)
parent_paths = list(module_path.parents)
root_parent = pathlib.Path(".")
if parent_paths:
root_parent = parent_paths[-1]
if root_parent == pathlib.Path("."):
module_path = repository_path.joinpath(module_path)
if str(module_path.parent.absolute()) not in sys.path:
sys.path.append(str(module_path.parent.absolute()))
try:
module = importlib.import_module(module_path.name)
class_ = getattr(module, class_name)
except Exception as e:
logger.exception(
"can't instantiate operator: %s %s", operator_config.name, e
)
raise e
elif issubclass(operator_config.implementation, Operator):
class_ = operator_config.implementation
else:
logger.exception(
"can't instantiate operator: %s",
operator_config.name,
)
raise Exception("invalid implementation type")
try:
if operator_config.log_level is None:
operator_config.log_level = self._log_level
operator_logger = get_logger(
log_level=operator_config.log_level,
logger_name=f"OPERATOR{(operator_config.name,operator_config.version)}",
log_file=self._log_file,
)
if (
class_ == TritonCoreOperator
or issubclass(class_, TritonCoreOperator)
) and not self._triton_core:
if not TRITON_CORE_AVAILABLE:
raise ValueError(
"Please install Triton Core to use a Triton Core Operator"
)
if not self._consolidate_logs and self._log_file:
log_file = pathlib.Path(self._log_file)
stem = log_file.stem
suffix = log_file.suffix
triton_log_path = str(
log_file.parent / f"{stem}.triton{suffix}"
)
else:
triton_log_path = str(self._log_file)
self._triton_core = TritonCore(
model_repository=".",
log_error=True,
log_verbose=self._log_level,
strict_model_config=False,
model_control_mode=ModelControlMode.EXPLICIT,
log_file=triton_log_path,
).start(wait_until_ready=True)
operator = class_(
operator_config.name,
operator_config.version,
self._request_plane,
self._data_plane,
operator_config.parameters,
operator_config.repository,
operator_logger,
self._triton_core,
)
except Exception as e:
logger.exception(
"can't instantiate operator: %s %s", operator_config.name, e
)
raise e
operator_key = (operator_config.name, operator_config.version)
self._operators[operator_key] = operator
            self._max_inflight_requests[operator] = operator_config.max_inflight_requests
self._inflight_requests[operator] = 0
self._background_tasks[operator] = set()
self._completion_conds[operator] = asyncio.Condition()
async def _process_request(self, request):
logger.debug("\n\nserver received request: \n\n%s\n\n", request)
operator_key = (request.model_name, int(request.model_version))
if operator_key in self._operators:
operator = self._operators[operator_key]
self._requests_received[operator] += 1
remote_request = RemoteInferenceRequest.from_model_infer_request(
request, self._data_plane, self._request_plane
)
await operator.execute([remote_request])
else:
logger.warning("Received request for unknown operator")
async def _process_request_task(self, operator, name, version):
requests = await self._request_plane.pull_requests(name, str(version))
        # When a request is received, notify the handler so it can
        # pull more requests if capacity permits.
async with self._completion_conds[operator]:
self._inflight_requests[operator] += 1
logger.debug(f"{operator} inflight: {self._inflight_requests[operator]}")
self._completion_conds[operator].notify()
# Process request received from the request plane
async for request in requests:
await self._process_request(request)
        # The request has been processed; new requests may now
        # be pulled.
async with self._completion_conds[operator]:
self._inflight_requests[operator] -= 1
logger.debug(f"{operator} inflight {self._inflight_requests[operator]}")
self._completion_conds[operator].notify()
async def _add_process_request_task(self, operator, name, version):
task = asyncio.create_task(self._process_request_task(operator, name, version))
self._background_tasks[operator].add(task)
task.add_done_callback(self._background_tasks[operator].discard)
async def _request_handler(self, operator, name, version):
while not self._stop_requested:
async with self._completion_conds[operator]:
# TODO: Instead of pulling a fixed number of requests try
# querying the model status to understand whether or not
# to pull more requests.
if (
self._inflight_requests[operator]
< self._max_inflght_requests[operator]
):
await self._add_process_request_task(operator, name, version)
                # Block the handler until it is notified.
                # New tasks are created only when needed so that,
                # at any given time, only a single task per model
                # is pulling from the request plane.
await self._completion_conds[operator].wait()
async def _initialize_request_handlers(self):
handlers = []
for (name, version), operator in self._operators.items():
logger.info(f"Starting {name} handler...")
handlers.append(self._request_handler(operator, name, version))
await asyncio.gather(*handlers)
async def serve(self):
try:
await self._request_plane.connect()
except Exception as e:
logger.exception(
"Encountered an error when trying to connect to request plane"
)
raise e
try:
self._data_plane.connect()
except Exception as e:
logger.exception(
"Encountered and error when trying to connect to data plane"
)
raise e
error = None
try:
self._import_operators()
logger.info("Worker started...")
await self._initialize_request_handlers()
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception("Encountered an error in worker: %s", e)
self._stop_requested = True
error = e
logger.info("worker store: %s", list(self._data_plane._tensor_store.keys()))
logger.info("Worker stopped...")
logger.info(
"Hosted Operators: %s Requests Received: %s Responses Sent: %s",
self._operators,
self._requests_received,
RemoteResponseSender.response_counts,
)
await self._request_plane.close()
self._data_plane.close()
if self._metrics_server:
self._metrics_server.should_exit = True
await self._metrics_server.shutdown()
return error
async def shutdown(self, signal):
logger.info("Received exit signal %s...", signal.name)
self._stop_requested = True
try:
if self._data_plane:
self._data_plane.close()
except Exception as e:
logger.exception("Failed to close the data plane: %s", e)
try:
if self._request_plane:
await self._request_plane.close()
except Exception as e:
logger.exception("Failed to close the request plane: %s", e)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
logger.info("Cancelling %s outstanding tasks", len(tasks))
        for task in tasks:
            task.cancel()
if self._triton_core:
self._triton_core.stop()
if self._metrics_server:
self._metrics_server.should_exit = True
await self._metrics_server.shutdown()
def _setup_metrics_server(self):
import uvicorn
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
app = FastAPI()
log_config = get_logger_config(
logger_name="uvicorn.error",
log_level=self._log_level,
log_file=self._log_file,
)
config = uvicorn.Config(
app,
port=self._metrics_port,
log_level=self._log_level,
log_config=log_config,
)
server = uvicorn.Server(config)
@app.get("/metrics", response_class=PlainTextResponse)
def metrics() -> str:
if self._triton_core:
return self._triton_core.metrics()
else:
return ""
return server
@staticmethod
def exception_handler(loop, context):
        # Get details of the exception; "exception" is not always
        # present in the handler context.
        exception = context.get("exception")
        message = context.get("message", "")
# log exception
logger.error(f"Task failed, msg={message}, exception={exception}")
async def _wait_for_tasks(self, loop):
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
try:
await asyncio.gather(*tasks, return_exceptions=True)
except asyncio.CancelledError as e:
logger.exception("Cancelled in task clean-up: %s", e)
except Exception as e:
logger.exception("Encountered an error in task clean-up: %s", e)
logger.info("Stopping the event loop")
loop.stop()
def start(self):
exit_condition = None
logger = get_logger(log_level=self._log_level, log_file=self._log_file)
logger.info(f"Starting Worker ==> {self._name}")
loop = asyncio.get_event_loop()
loop.set_exception_handler(Worker.exception_handler)
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
# Note: mypy has known issues inferring
# types of lambdas that include capturing loop variables
for sig in signals:
loop.add_signal_handler(
sig, lambda s=sig: asyncio.create_task(self.shutdown(s)) # type: ignore
)
serve_result = None
try:
if self._metrics_port:
serve_result = loop.create_task(self.serve())
self._metrics_server = self._setup_metrics_server()
assert self._metrics_server, "Unable to start metrics server"
loop.run_until_complete(self._metrics_server.serve())
else:
serve_result = loop.run_until_complete(self.serve())
except asyncio.CancelledError:
logger.info("Worker cancelled!")
finally:
loop.run_until_complete(self._wait_for_tasks(loop))
loop.close()
logger.info("Successfully shutdown worker.")
if isinstance(serve_result, asyncio.Task):
exit_condition = serve_result.result()
else:
exit_condition = serve_result
if exit_condition is not None:
sys.exit(1)
else:
sys.exit(0)
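
A minimal launch sketch under stated assumptions: the OperatorConfig field names are inferred from how _import_operators() consumes them, and the "module:Class" implementation string, repository path, and NATS URL are hypothetical placeholders.

def main() -> None:
    operator = OperatorConfig(
        name="identity",  # hypothetical operator name
        version=1,
        implementation="identity_operator:Identity",  # "module:Class" string form
        repository="./operators",  # hypothetical path
        max_inflight_requests=4,
        parameters={},
    )
    config = WorkerConfig(
        request_plane=NatsRequestPlane,
        data_plane=UcpDataPlane,
        # NatsRequestPlane takes the server URI as its first positional argument.
        request_plane_args=(["nats://localhost:4223"], {}),
        operators=[operator],
        log_dir="./logs",
    )
    # start() installs signal handlers and blocks until shutdown.
    Worker(config).start()

if __name__ == "__main__":
    main()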
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import subprocess
import time
from pathlib import Path
import pytest
import pytest_asyncio
from triton_distributed.icp.nats_request_plane import NatsServer
logger = logging.getLogger(__name__)
NATS_PORT = 4223
TEST_API_SERVER_MODEL_REPO_PATH = "integration/api_server/models"
def pytest_addoption(parser):
parser.addoption(
"--basetemp-permissions",
action="store",
help="Permissions of the base temporary directory used by tmp_path, as octal value. Examples: 700 (default), 750, 770",
)
@pytest.fixture(scope="session")
def log_dir(request, tmp_path_factory):
log_dir = tmp_path_factory.mktemp("logs")
try:
permissions = request.config.getoption("--basetemp-permissions")
except ValueError:
permissions = False
if permissions:
basetemp = request.config._tmp_path_factory.getbasetemp()
os.chmod(basetemp, int(permissions, 8))
os.chmod(log_dir, int(permissions, 8))
return log_dir
@pytest.fixture(scope="session")
def nats_server(log_dir):
server = NatsServer(log_dir=log_dir / "nats")
yield server
del server
@pytest.fixture(scope="session")
def api_server(log_dir):
command = [
"tritonserver",
"--model-store",
str(Path(__file__).parent.resolve() / TEST_API_SERVER_MODEL_REPO_PATH),
]
api_server_log_dir = log_dir / "api_server"
os.makedirs(api_server_log_dir, exist_ok=True)
with open(api_server_log_dir / "api_server.stdout.log", "wt") as output_:
with open(api_server_log_dir / "api_server.stderr.log", "wt") as output_err:
process = subprocess.Popen(
command, stdin=subprocess.DEVNULL, stdout=output_, stderr=output_err
)
time.sleep(10)
yield process
process.terminate()
process.wait()
print("Successfully cleaned-up T2 API server")
@pytest_asyncio.fixture
async def aio_benchmark(benchmark):
async def run_async_coroutine(func, *args, **kwargs):
return await func(*args, **kwargs)
def _wrapper(func, *args, **kwargs):
if asyncio.iscoroutinefunction(func):
@benchmark
def _():
future = asyncio.ensure_future(
run_async_coroutine(func, *args, **kwargs)
)
return asyncio.get_event_loop().run_until_complete(future)
else:
benchmark(func, *args, **kwargs)
return _wrapper
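
A usage sketch (assuming pytest-benchmark is installed and provides the `benchmark` fixture this wraps): the wrapper accepts either a plain callable or a coroutine function, driving coroutines to completion on the event loop for each round.

import asyncio

async def _echo(value: int) -> int:
    await asyncio.sleep(0)
    return value

def test_echo_benchmark(aio_benchmark):
    # A coroutine function is detected and benchmarked via the event loop.
    aio_benchmark(_echo, 42)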
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import gc
import json
import queue
import threading
import traceback
import uuid
import triton_python_backend_utils as pb_utils
import ucp
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.runtime.remote_operator import RemoteOperator
class TritonPythonModel:
"""
    This model allows Triton to act as an API server for T3 ICP.
"""
@staticmethod
def auto_complete_config(auto_complete_model_config):
inputs = [
{"name": "query", "data_type": "TYPE_STRING", "dims": [1]},
{
"name": "request_output_len",
"data_type": "TYPE_INT32",
"dims": [1],
},
]
outputs = [{"name": "output", "data_type": "TYPE_STRING", "dims": [-1]}]
# Store the model configuration as a dictionary.
config = auto_complete_model_config.as_dict()
input_names = []
output_names = []
for input in config["input"]:
input_names.append(input["name"])
for output in config["output"]:
output_names.append(output["name"])
# Add only missing inputs and output to the model configuration.
for input in inputs:
if input["name"] not in input_names:
auto_complete_model_config.add_input(input)
for output in outputs:
if output["name"] not in output_names:
auto_complete_model_config.add_output(output)
        # We need to use the decoupled transaction policy to saturate T3.
auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))
        # Disable batching in Triton.
auto_complete_model_config.set_max_batch_size(0)
return auto_complete_model_config
async def _connect(self):
ucp.reset()
self._request_plane = NatsRequestPlane(self._request_plane_uri)
self._data_plane = UcpDataPlane()
self._data_plane.connect()
await self._request_plane.connect()
async def _disconnect(self, timeout):
self._data_plane.close(wait_for_release=timeout)
await self._request_plane.close()
async def _await_shutdown(self):
"""
Primary coroutine running on the engine event loop. This coroutine is responsible for
keeping the engine alive until a shutdown is requested.
"""
# first await the shutdown signal
        while not self._shutdown_event.is_set():
await asyncio.sleep(5)
# Wait for the ongoing_requests
while self._ongoing_request_count > 0:
self.logger.log_info(
"[API Server] Awaiting remaining {} requests".format(
self._ongoing_request_count
)
)
await asyncio.sleep(5)
for task in asyncio.all_tasks(loop=self._loop):
if task is not asyncio.current_task():
task.cancel()
self.logger.log_info("[API Server] Shutdown complete")
def _create_task(self, coro):
"""
Creates a task on the event loop which is running on a separate thread.
"""
assert (
self._shutdown_event.is_set() is False
), "Cannot create tasks after shutdown has been requested"
return asyncio.run_coroutine_threadsafe(coro, self._loop)
def _event_loop(self, loop):
"""
Runs the engine's event loop on a separate thread.
"""
asyncio.set_event_loop(loop)
self._loop.run_until_complete(self._await_shutdown())
def initialize(self, args):
model_config = json.loads(args["model_config"])
self.logger = pb_utils.Logger
# Starting asyncio event loop to process the received requests asynchronously.
self._loop = asyncio.get_event_loop()
self._event_thread = threading.Thread(
target=self._event_loop, args=(self._loop,)
)
self._shutdown_event = asyncio.Event()
self._event_thread.start()
self._request_plane_uri = model_config["parameters"]["request_plane_uri"][
"string_value"
]
future = self._create_task(self._connect())
try:
_ = future.result(timeout=5)
except TimeoutError:
self.logger.log_error(
"The connection to T3 ICP took too long, cancelling the task..."
)
future.cancel()
except Exception as exc:
self.logger.log_error(
f"The connection to T3 ICP raised an exception: {exc!r}"
)
self._remote_worker_name = model_config["parameters"]["remote_worker_name"][
"string_value"
]
self._remote_operator = RemoteOperator(
self._remote_worker_name, self._request_plane, self._data_plane
)
        # Start the response thread. It lets the API server keep making
        # progress while response senders deliver responses to the server
        # frontend.
self._response_queue = queue.Queue()
self._response_thread = threading.Thread(target=self.response_loop)
self._response_thread.start()
# Counter to keep track of ongoing request counts
self._ongoing_request_count = 0
for output_name in ["output"]:
setattr(
self,
output_name.lower() + "_dtype",
pb_utils.triton_string_to_numpy(
pb_utils.get_output_config_by_name(model_config, output_name)[
"data_type"
]
),
)
def response_loop(self):
while True:
item = self._response_queue.get()
# To signal shutdown a None item will be added to the queue.
if item is None:
break
response_sender, response, response_flag = item
del item
try:
response_sender.send(response, response_flag)
except Exception as e:
self.logger.log_error(
f"An error occurred while sending a response: {e}"
)
finally:
if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
self._ongoing_request_count -= 1
del response_sender
if self._ongoing_request_count == 0:
gc.collect()
def execute(self, requests):
for request in requests:
if request is not None:
self._create_task(self.remote_execute(request))
return None
async def remote_execute(self, request):
response_sender = request.get_response_sender()
self._ongoing_request_count += 1
query = pb_utils.get_input_tensor_by_name(request, "query").as_numpy()
request_output_len = pb_utils.get_input_tensor_by_name(
request, "request_output_len"
).as_numpy()
request_id = str(uuid.uuid4())
infer_request = self._remote_operator.create_request(
inputs={"query": query, "request_output_len": request_output_len},
request_id=request_id,
)
try:
async for response in await self._remote_operator.async_infer(
inference_request=infer_request
):
if response.error:
raise pb_utils.TritonModelException(response.error.message())
if not response.final:
output = response.outputs["output"]
output_value = output.to_bytes_array()
                    # Forward the remote operator's output back to the client.
output_tensor = pb_utils.Tensor(
"output", output_value.astype(self.output_dtype)
)
inference_response = pb_utils.InferenceResponse(
output_tensors=[output_tensor]
)
self._response_queue.put_nowait(
(response_sender, inference_response, 0)
)
        except Exception as e:
            self.logger.log_error(
                f"Failed running remote inference: {traceback.format_exc()}"
            )
            raise pb_utils.TritonModelException(repr(e))
self._response_queue.put_nowait(
(response_sender, None, pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
)
return None
def finalize(self):
self.logger.log_info("[API Server] Issuing finalize to API Server")
future = self._create_task(self._disconnect(timeout=5))
try:
_ = future.result(timeout=7)
except TimeoutError:
self.logger.log_error(
"The connection to T3 ICP took too long, cancelling the task..."
)
future.cancel()
except Exception as exc:
self.logger.log_error(
f"The connection to T3 ICP raised an exception: {exc!r}"
)
self._shutdown_event.set()
# Shutdown the event thread.
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None
# Shutdown the response thread.
self._response_queue.put(None)
if self._response_thread is not None:
self._response_thread.join()
self._response_thread = None
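
A client-side sketch under stated assumptions: the server address is a placeholder, the model name comes from the config below, and tritonclient must be installed. Because the model is decoupled, responses arrive through the gRPC stream callback rather than as a single return value.

import queue

import numpy as np
import tritonclient.grpc as grpcclient

def infer_once(url: str = "localhost:8001") -> None:
    results: queue.Queue = queue.Queue()
    client = grpcclient.InferenceServerClient(url)
    try:
        # Each streamed response is delivered to this callback.
        client.start_stream(
            callback=lambda result, error: results.put((result, error))
        )
        query = grpcclient.InferInput("query", [1], "BYTES")
        query.set_data_from_numpy(np.array([b"hello"], dtype=np.object_))
        output_len = grpcclient.InferInput("request_output_len", [1], "INT32")
        output_len.set_data_from_numpy(np.array([16], dtype=np.int32))
        client.async_stream_infer("mock_disaggregated_serving", [query, output_len])
        result, error = results.get()
        if error is None:
            print(result.as_numpy("output"))
    finally:
        client.stop_stream()
        client.close()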
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "mock_disaggregated_serving"
backend: "python"
max_batch_size: 0
model_transaction_policy {
decoupled: true
}
parameters: {
key: "remote_worker_name"
value: {
string_value: "mock_disaggregated_serving"
}
}
parameters: {
key: "request_plane_uri"
value: {
string_value: "nats://localhost:4223"
}
}
# Add more parameters as per requirement
instance_group [
{
count: 1
kind : KIND_CPU
}
]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import numpy
from triton_distributed.runtime import Operator, RemoteInferenceRequest, RemoteOperator
class AddMultiplyDivide(Operator):
def __init__(
self,
name,
version,
request_plane,
data_plane,
parameters,
repository,
logger,
triton_core,
):
self._triton_core = triton_core
self._request_plane = request_plane
self._data_plane = data_plane
self._parameters = parameters
self._add_model = RemoteOperator("add", self._request_plane, self._data_plane)
self._multiply_model = RemoteOperator(
"multiply", self._request_plane, self._data_plane
)
self._divide_model = RemoteOperator(
"divide", self._request_plane, self._data_plane
)
self._logger = logger
async def execute(self, requests: list[RemoteInferenceRequest]):
self._logger.debug("in execute!")
for request in requests:
outputs = {}
self._logger.debug(request.inputs)
array = None
try:
array = numpy.from_dlpack(request.inputs["int64_input"])
except Exception:
self._logger.exception("Failed to retrieve inputs")
self._logger.debug(array)
response = [
response
async for response in await self._add_model.async_infer(
inputs={"int64_input": array}
)
][0]
self._logger.debug(response)
for output_name, output_value in response.outputs.items():
outputs[f"{response.model_name}_{output_name}"] = output_value
addition_output_partial = response.outputs["int64_output_partial"]
addition_output_total = response.outputs["int64_output_total"]
            multiply_responses = self._multiply_model.async_infer(
inputs={"int64_input": addition_output_partial}, raise_on_error=False
)
divide_responses = self._divide_model.async_infer(
inputs={
"int64_input": addition_output_partial,
"int64_input_divisor": addition_output_total,
},
raise_on_error=False,
)
error = None
            for result in asyncio.as_completed([multiply_responses, divide_responses]):
responses = await result
async for response in responses:
self._logger.debug(f"response! {response}")
self._logger.debug(f"error! {response.error}")
if response.error is not None:
error = response.error
break
for output_name, output_value in response.outputs.items():
outputs[f"{response.model_name}_{output_name}"] = output_value
if error is not None:
await request.response_sender().send(error=error, final=True)
else:
await request.response_sender().send(outputs=outputs, final=True)
for output in outputs.values():
del output
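
A deployment sketch (assuming Worker, WorkerConfig, and OperatorConfig are importable from triton_distributed.runtime, as the names above are): Worker._import_operators also accepts an Operator subclass directly, so this workflow can be registered by class rather than by a "module:Class" string.

from triton_distributed.runtime import OperatorConfig, Worker, WorkerConfig

def serve_workflow() -> None:
    operator = OperatorConfig(
        name="add_multiply_divide",  # hypothetical operator name
        version=1,
        implementation=AddMultiplyDivide,  # class object, not a string
        max_inflight_requests=8,
        parameters={},
    )
    # The add, multiply, and divide models are expected to be served by
    # other workers reachable over the same request plane.
    Worker(WorkerConfig(operators=[operator])).start()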
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
from triton_distributed.runtime import Operator, RemoteInferenceRequest
class Identity(Operator):
"""
    This is a dummy workflow that echoes a single input back as an output.
"""
def __init__(
self,
name,
version,
request_plane,
data_plane,
params,
repository,
logger,
triton_core,
):
self._triton_core = triton_core
self._request_plane = request_plane
self._data_plane = data_plane
self._params = params
self._logger = logger
async def execute(self, requests: list[RemoteInferenceRequest]):
for request in requests:
try:
array = numpy.from_dlpack(request.inputs["input"])
            except Exception as e:
                self._logger.exception("Failed to retrieve inputs")
                await request.response_sender().send(final=True, error=e)
                continue
self._logger.debug("Operator received inputs")
outputs: dict[str, numpy.ndarray] = {"output": array}
store_outputs_in_response = False
if "store_outputs_in_response" in self._params:
store_outputs_in_response = self._params["store_outputs_in_response"]
store_outputs_in_response_set = set()
if store_outputs_in_response:
store_outputs_in_response_set.add("output")
await request.response_sender().send(
outputs=outputs,
final=True,
store_outputs_in_response=store_outputs_in_response_set,
)
for output in outputs.values():
del output
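
A client sketch (assuming RemoteOperator is importable from triton_distributed.runtime and both planes are already connected): the operator echoes the "input" tensor back as "output" in a single final response.

import numpy
from triton_distributed.runtime import RemoteOperator

async def call_identity(request_plane, data_plane) -> numpy.ndarray:
    identity = RemoteOperator("identity", request_plane, data_plane)
    async for response in await identity.async_infer(
        inputs={"input": numpy.arange(4, dtype=numpy.int64).reshape(2, 2)}
    ):
        # Copy out before the response tensors are released.
        return numpy.from_dlpack(response.outputs["output"]).copy()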
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tritonserver import TritonError
from triton_distributed.runtime.operator import Operator
from triton_distributed.runtime.remote_operator import RemoteOperator
from triton_distributed.runtime.remote_request import RemoteInferenceRequest
class MockDisaggregatedServing(Operator):
def __init__(
self,
name,
version,
request_plane,
data_plane,
params,
repository,
logger,
triton_core,
):
self._triton_core = triton_core
self._request_plane = request_plane
self._data_plane = data_plane
self._params = params
self._preprocessing_model = RemoteOperator(
"preprocessing", self._request_plane, self._data_plane
)
self._context_model = RemoteOperator(
"context", self._request_plane, self._data_plane
)
self._generate_model = RemoteOperator(
"generation", self._request_plane, self._data_plane
)
self._postprocessing_model = RemoteOperator(
"postprocessing", self._request_plane, self._data_plane
)
self._logger = logger
async def _run_generate(self, context_response, response_sender):
error = None
generate_inputs = {}
if not error:
for output_name in ["KV_CACHE", "REQUEST_OUTPUT_LEN"]:
if output_name not in context_response.outputs.keys():
error_msg = f"Expected '{output_name}' as output in llm model response, Got outputs {context_response.outputs.keys()}"
self._logger.error(error_msg)
self._logger.debug(f"context_response: {context_response}")
error = TritonError(error_msg)
else:
generate_inputs[output_name] = context_response.outputs[output_name]
postproc_result = []
generate_responses = []
if not error:
try:
# TODO: Run post-processing in parallel with generate
async for response in await self._generate_model.async_infer(
inputs=generate_inputs
):
generate_responses.append(response)
self._logger.debug(f"Received response {response}")
if not generate_responses[-1].final:
postproc_result.append(
await self._run_postprocessing(
generate_responses[-1], response_sender, final=False
)
)
except Exception as e:
error = TritonError(repr(e))
self._logger.exception("Failed to run post-processing")
for generate_response in generate_responses:
for tensor in generate_response.outputs.values():
del tensor
return postproc_result
async def _run_postprocessing(self, llm_response, response_sender, final):
self._logger.debug(f"going to run_post_processing final={final}")
postproc_inputs = {}
for output_name in ["OUTPUT_IDS", "SEQUENCE_LENGTH"]:
if output_name not in llm_response.outputs.keys():
error_msg = f"Expected '{output_name}' as output in llm model response, Got outputs {llm_response.outputs.items()}"
self._logger.error(error_msg)
self._logger.debug(f"llm_response: {llm_response}")
raise Exception(error_msg)
else:
postproc_inputs[output_name] = llm_response.outputs[output_name]
outputs = {}
postproc_responses = []
# TODO: Run post-processing in parallel with generate
self._logger.debug(f"Sending request to post-process {postproc_inputs}")
sending = []
async for response in await self._postprocessing_model.async_infer(
inputs=postproc_inputs
):
self._logger.debug(f"Received response {response}")
self._logger.debug(f"Got response from post-process {response}")
postproc_responses.append(response)
outputs["output"] = postproc_responses[-1].outputs["OUTPUT"]
sending.append(await response_sender().send(outputs=outputs, final=False))
return sending
async def execute(self, requests: list[RemoteInferenceRequest]):
self._logger.debug("in execute!")
error = None
for request in requests:
"""
Pre-processing
"""
preproc_responses = []
async for response in await self._preprocessing_model.async_infer(
inference_request=request
):
preproc_responses.append(response)
            if len(preproc_responses) != 1:
                error_msg = f"Expected exactly one response from the preprocessing model, got {len(preproc_responses)}"
self._logger.error(error_msg)
error = TritonError(error_msg)
context_inputs = {}
if not error:
for output_name in ["INPUT_IDS", "INPUT_LENGTH", "REQUEST_OUTPUT_LEN"]:
if output_name not in preproc_responses[0].outputs.keys():
error_msg = f"Expected '{output_name}' as output in preprocessing model response, Got outputs {response.outputs.keys()}"
self._logger.error(error_msg)
error = TritonError(error_msg)
else:
context_inputs[output_name] = preproc_responses[0].outputs[
output_name
]
"""
Prefill
"""
context_responses = []
postproc_result = []
if not error:
async for response in await self._context_model.async_infer(
inputs=context_inputs
):
context_responses.append(response)
            if not error:
                if len(context_responses) != 1:
                    error_msg = f"Expected exactly one response from the context model, got {len(context_responses)}"
                    self._logger.error(error_msg)
                    error = TritonError(error_msg)
else:
postproc_result.append(
self._run_postprocessing(
context_responses[0], request.response_sender, final=False
)
)
"""
Generate
"""
if not error:
postproc_result.append(
self._run_generate(context_responses[0], request.response_sender)
)
for result in postproc_result:
try:
await result
except Exception as e:
self._logger.exception(
f"Failed getting response from post-processing {result}: {e}"
)
error = TritonError(repr(e))
            if preproc_responses:
                for tensor in preproc_responses[0].outputs.values():
                    del tensor
            if context_responses:
                for tensor in context_responses[0].outputs.values():
                    del tensor
if error:
await request.response_sender().send(error=error, final=True)
else:
await request.response_sender().send(final=True)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
import triton_python_backend_utils as pb_utils
try:
import cupy
except Exception:
cupy = None
class TritonPythonModel:
@staticmethod
def auto_complete_config(auto_complete_model_config):
inputs = []
outputs = []
dims = [-1, -1]
optional = True
for data_type in ["type_int64"]:
type_name = data_type.split("_")[1].lower()
input_name = f"{type_name}_input"
output_name_1 = f"{type_name}_output_total"
output_name_2 = f"{type_name}_output_partial"
inputs.append(
{
"name": input_name,
"data_type": data_type,
"dims": dims,
"optional": optional,
}
)
outputs.append(
{"name": output_name_1, "data_type": data_type, "dims": dims}
)
outputs.append(
{"name": output_name_2, "data_type": data_type, "dims": dims}
)
outputs.append(
{"name": "output_parameters", "data_type": "TYPE_STRING", "dims": [1]}
)
for input_ in inputs:
auto_complete_model_config.add_input(input_)
for output in outputs:
auto_complete_model_config.add_output(output)
auto_complete_model_config.set_max_batch_size(0)
return auto_complete_model_config
def initialize(self, args):
self._model_config = json.loads(args["model_config"])
self._request_gpu_memory = False
if "parameters" in self._model_config:
parameters = self._model_config["parameters"]
if (
"request_gpu_memory" in parameters
and parameters["request_gpu_memory"]["string_value"] == "True"
):
self._request_gpu_memory = True
def execute(self, requests):
responses = []
for request in requests:
output_tensors = []
for input_tensor in request.inputs():
input_value = input_tensor.as_numpy()
output_value_partial = np.array([[x.sum() for x in input_value]])
output_value_total = np.array([[input_value.sum()]])
if self._request_gpu_memory:
output_value_partial = cupy.array(output_value_partial)
output_value_total = cupy.array(output_value_total)
output_tensor = pb_utils.Tensor.from_dlpack(
input_tensor.name().replace("input", "output_partial"),
output_value_partial,
)
output_tensors.append(output_tensor)
output_tensor = pb_utils.Tensor.from_dlpack(
input_tensor.name().replace("input", "output_total"),
output_value_total,
)
output_tensors.append(output_tensor)
else:
output_tensor = pb_utils.Tensor(
input_tensor.name().replace("input", "output_partial"),
output_value_partial,
)
output_tensors.append(output_tensor)
output_tensor = pb_utils.Tensor(
input_tensor.name().replace("input", "output_total"),
output_value_total,
)
output_tensors.append(output_tensor)
output_parameters = np.array([request.parameters()]).astype(np.object_)
output_tensors.append(
pb_utils.Tensor("output_parameters", output_parameters)
)
responses.append(
pb_utils.InferenceResponse(
output_tensors=output_tensors,
)
)
return responses
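
A quick illustration of the reduction this model performs, assuming an int64_input of shape (2, 3):

>>> import numpy as np
>>> x = np.array([[1, 2, 3], [4, 5, 6]])
>>> np.array([[row.sum() for row in x]])   # int64_output_partial
array([[ 6, 15]])
>>> np.array([[x.sum()]])                  # int64_output_total
array([[21]])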
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
backend: "python"
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
def initialize(self, args):
model_config = json.loads(args["model_config"])
self._context_delay = (
int(model_config["parameters"]["context_delay_ms"]["string_value"])
) / 1000
for output_name in [
"KV_CACHE",
"OUTPUT_IDS",
"SEQUENCE_LENGTH",
"REQUEST_OUTPUT_LEN",
]:
setattr(
self,
output_name.lower() + "_dtype",
pb_utils.triton_string_to_numpy(
pb_utils.get_output_config_by_name(model_config, output_name)[
"data_type"
]
),
)
def execute(self, requests):
responses = []
        for request in requests:
# Get input tensors
input_ids = pb_utils.get_input_tensor_by_name(
request, "INPUT_IDS"
).as_numpy()
input_lengths = pb_utils.get_input_tensor_by_name(
request, "INPUT_LENGTH"
).as_numpy()
request_output_len = pb_utils.get_input_tensor_by_name(
request, "REQUEST_OUTPUT_LEN"
).as_numpy()
time.sleep(self._context_delay)
# Create output tensors. You need pb_utils.Tensor
# objects to create pb_utils.InferenceResponse.
kv_cache_tensor = pb_utils.Tensor(
"KV_CACHE", input_ids.astype(self.kv_cache_dtype)
)
output_ids_tensor = pb_utils.Tensor(
"OUTPUT_IDS", input_ids.astype(self.output_ids_dtype)
)
sequence_length_tensor = pb_utils.Tensor(
"SEQUENCE_LENGTH", input_lengths.astype(self.sequence_length_dtype)
)
request_output_len_tensor = pb_utils.Tensor(
"REQUEST_OUTPUT_LEN", request_output_len
)
inference_response = pb_utils.InferenceResponse(
output_tensors=[
kv_cache_tensor,
output_ids_tensor,
sequence_length_tensor,
request_output_len_tensor,
]
)
responses.append(inference_response)
return responses
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Emulates the tensorrt_llm config from:
# https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
name: "context"
backend: "python"
max_batch_size: 0
parameters: {
key: "context_delay_ms"
value: {
string_value: "1000"
}
}
input [
{
name: "INPUT_IDS"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "INPUT_LENGTH"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "REQUEST_OUTPUT_LEN"
data_type: TYPE_INT32
dims: [ 1 ]
}
# Add more inputs as per requirement.
# For simplicity only sticking with these
# inputs for preprocessing.
]
output [
# Section of the first request that returns the first token.
# These will be handed over directly to the post-processor
{
name: "OUTPUT_IDS"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "SEQUENCE_LENGTH"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "REQUEST_OUTPUT_LEN"
data_type: TYPE_INT32
dims: [ 1 ]
},
# Section of the second part of handover to the generate stage
{
# TODO: revisit how kv cache is being exposed to generate worker.
name: "KV_CACHE"
data_type: TYPE_INT32
dims: [ -1 ]
}
# Add more outputs as per requirement.
# For simplicity only sticking with these
# outputs for preprocessing.
]
# Add more parameters as per requirement
instance_group [
{
count: 1
kind : KIND_CPU
}
]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
import triton_python_backend_utils as pb_utils
try:
import cupy
except Exception:
cupy = None
class TritonPythonModel:
@staticmethod
def auto_complete_config(auto_complete_model_config):
inputs = []
outputs = []
dims = [-1, -1]
optional = True
for data_type in ["type_int64"]:
type_name = data_type.split("_")[1].lower()
input_name = f"{type_name}_input"
output_name = "fp64_output_partial"
inputs.append(
{
"name": input_name,
"data_type": data_type,
"dims": dims,
"optional": optional,
}
)
outputs.append({"name": output_name, "data_type": data_type, "dims": dims})
input_name = f"{type_name}_input_divisor"
inputs.append(
{
"name": input_name,
"data_type": data_type,
"dims": dims,
"optional": optional,
}
)
outputs.append(
{"name": "output_parameters", "data_type": "TYPE_STRING", "dims": [1]}
)
for input_ in inputs:
auto_complete_model_config.add_input(input_)
for output in outputs:
auto_complete_model_config.add_output(output)
auto_complete_model_config.set_max_batch_size(0)
return auto_complete_model_config
def initialize(self, args):
self._model_config = json.loads(args["model_config"])
self._request_gpu_memory = False
if "parameters" in self._model_config:
parameters = self._model_config["parameters"]
if (
"request_gpu_memory" in parameters
and parameters["request_gpu_memory"]["string_value"] == "True"
):
self._request_gpu_memory = True
def execute(self, requests):
responses = []
for request in requests:
output_tensors = []
divisor = pb_utils.get_input_tensor_by_name(request, "int64_input_divisor")
divisor = divisor.as_numpy()[0][0]
dividends = pb_utils.get_input_tensor_by_name(request, "int64_input")
dividends = dividends.as_numpy()
            # Elementwise division yields a 2-D float64 array matching dims [-1, -1].
            output_value = np.divide(dividends, divisor)
if self._request_gpu_memory:
output_value = cupy.array(output_value)
output_tensor = pb_utils.Tensor.from_dlpack(
"fp64_output_partial", output_value
)
else:
output_tensor = pb_utils.Tensor("fp64_output_partial", output_value)
output_tensors.append(output_tensor)
output_parameters = np.array([request.parameters()]).astype(np.object_)
output_tensors.append(
pb_utils.Tensor("output_parameters", output_parameters)
)
responses.append(
pb_utils.InferenceResponse(
output_tensors=output_tensors,
)
)
return responses
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
backend: "python"
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import threading
import time
import numpy as np
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
def initialize(self, args):
model_config = json.loads(args["model_config"])
self._output_token_latency = (
int(model_config["parameters"]["inter_token_latency_ms"]["string_value"])
) / 1000
# You must parse model_config. JSON string is not parsed here
self.model_config = model_config = json.loads(args["model_config"])
using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
model_config
)
if not using_decoupled:
raise pb_utils.TritonModelException(
"""the model `{}` can generate any number of responses per request,
enable decoupled transaction policy in model configuration to
serve this model""".format(
args["model_name"]
)
)
for output_name in ["OUTPUT_IDS", "SEQUENCE_LENGTH"]:
setattr(
self,
output_name.lower() + "_dtype",
pb_utils.triton_string_to_numpy(
pb_utils.get_output_config_by_name(model_config, output_name)[
"data_type"
]
),
)
        # To keep track of response threads so that we can delay
        # finalizing the model until all response threads
        # have completed.
self.inflight_thread_count = 0
self.inflight_thread_count_lck = threading.Lock()
def response_thread(self, response_sender, kv_cache, request_output_len):
for idx in range(request_output_len):
time.sleep(self._output_token_latency)
output_ids_tensor = pb_utils.Tensor(
"OUTPUT_IDS", kv_cache.astype(self.output_ids_dtype)
)
sequence_length = np.array([kv_cache.size])
sequence_length_tensor = pb_utils.Tensor(
"SEQUENCE_LENGTH", sequence_length.astype(self.sequence_length_dtype)
)
response = pb_utils.InferenceResponse(
output_tensors=[output_ids_tensor, sequence_length_tensor]
)
response_sender.send(response)
        # We must close the response sender to indicate to Triton that we are
        # done sending responses for the corresponding request. The response
        # sender cannot be used after it is closed; it is closed by sending
        # the TRITONSERVER_RESPONSE_COMPLETE_FINAL flag.
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
with self.inflight_thread_count_lck:
self.inflight_thread_count -= 1
def execute(self, requests):
        for request in requests:
# Get input tensors
kv_cache = pb_utils.get_input_tensor_by_name(request, "KV_CACHE").as_numpy()
request_output_len = pb_utils.get_input_tensor_by_name(
request, "REQUEST_OUTPUT_LEN"
).as_numpy()
            # Start a separate thread to send the responses for the request;
            # sending the responses back is delegated entirely to that thread.
thread = threading.Thread(
target=self.response_thread,
args=(
                    request.get_response_sender(),
kv_cache,
request_output_len[0],
),
)
            # A model using the decoupled transaction policy is not required
            # to send all responses for the current request before returning
            # from execute. To demonstrate the flexibility of the decoupled
            # API, the response thread runs entirely independently of the
            # execute thread.
thread.daemon = True
with self.inflight_thread_count_lck:
self.inflight_thread_count += 1
thread.start()
return None
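    def finalize(self):
        # Assumed addition (not in the original file): Triton calls finalize()
        # when the model is being unloaded. Per the inflight_thread_count
        # bookkeeping above, wait for all in-flight response threads to
        # drain before letting the unload proceed.
        while True:
            with self.inflight_thread_count_lck:
                if self.inflight_thread_count == 0:
                    break
            time.sleep(0.1)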
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Emulates the tensorrt_llm config from:
# https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
name: "generation"
backend: "python"
max_batch_size: 0
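# Decoupled transactions let the generation worker stream one response per
# generated token instead of returning a single response when execute returns.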
model_transaction_policy {
decoupled: true
}
parameters: {
key: "inter_token_latency_ms"
value: {
string_value: "1000"
}
}
input [
{
# TODO: revisit how kv cache is being exposed to generate worker.
name: "KV_CACHE"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "REQUEST_OUTPUT_LEN"
data_type: TYPE_INT32
dims: [ 1 ]
}
  # Add more inputs as needed. For simplicity, only
  # these inputs are modeled for the generation stage.
]
output [
{
name: "OUTPUT_IDS"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "SEQUENCE_LENGTH"
data_type: TYPE_INT32
dims: [ 1 ]
}
  # Add more outputs as needed. For simplicity, only
  # these outputs are modeled for the generation stage.
]
# Add more parameters as needed.
instance_group [
{
count: 1
    kind: KIND_CPU
}
]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
import triton_python_backend_utils as pb_utils
try:
import cupy
except Exception:
cupy = None
class TritonPythonModel:
@staticmethod
def auto_complete_config(auto_complete_model_config):
inputs = []
outputs = []
dims = [-1, -1]
optional = True
config = auto_complete_model_config.as_dict()
for data_type in pb_utils.TRITON_STRING_TO_NUMPY.keys():
type_name = data_type.split("_")[1].lower()
input_name = f"{type_name}_input"
output_name = f"{type_name}_output"
inputs.append(
{
"name": input_name,
"data_type": data_type,
"dims": dims,
"optional": optional,
}
)
outputs.append({"name": output_name, "data_type": data_type, "dims": dims})
outputs.append(
{"name": "output_parameters", "data_type": "TYPE_STRING", "dims": [1]}
)
for input_ in inputs:
auto_complete_model_config.add_input(input_)
for output in outputs:
auto_complete_model_config.add_output(output)
auto_complete_model_config.set_max_batch_size(0)
if "decoupled" in config["parameters"]:
if config["parameters"]["decoupled"]["string_value"] == "True":
auto_complete_model_config.set_model_transaction_policy(
{"decoupled": True}
)
return auto_complete_model_config
def initialize(self, args):
self._model_config = json.loads(args["model_config"])
self._decoupled = self._model_config.get("model_transaction_policy", {}).get(
"decoupled"
)
self._request_gpu_memory = False
if "parameters" in self._model_config:
parameters = self._model_config["parameters"]
if (
"request_gpu_memory" in parameters
and parameters["request_gpu_memory"]["string_value"] == "True"
):
self._request_gpu_memory = True
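    # In decoupled mode the inputs of each request are echoed back through the
    # response sender; otherwise execute() returns one response per request,
    # optionally in GPU memory via cupy and DLPack.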
def execute_decoupled(self, requests):
for request in requests:
sender = request.get_response_sender()
output_tensors = []
for input_tensor in request.inputs():
input_value = input_tensor.as_numpy()
output_tensor = pb_utils.Tensor(
input_tensor.name().replace("input", "output"), input_value
)
output_tensors.append(output_tensor)
sender.send(pb_utils.InferenceResponse(output_tensors=output_tensors))
sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
return None
def execute(self, requests):
if self._decoupled:
return self.execute_decoupled(requests)
responses = []
for request in requests:
output_tensors = []
for input_tensor in request.inputs():
input_value = input_tensor.as_numpy()
if self._request_gpu_memory:
input_value = cupy.array(input_value)
output_tensor = pb_utils.Tensor.from_dlpack(
input_tensor.name().replace("input", "output"), input_value
)
else:
output_tensor = pb_utils.Tensor(
input_tensor.name().replace("input", "output"), input_value
)
output_tensors.append(output_tensor)
output_parameters = np.array([request.parameters()]).astype(np.object_)
output_tensors.append(
pb_utils.Tensor("output_parameters", output_parameters)
)
responses.append(
pb_utils.InferenceResponse(
output_tensors=output_tensors,
)
)
return responses
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
backend: "python"
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
import triton_python_backend_utils as pb_utils
try:
import cupy
except Exception:
cupy = None
class TritonPythonModel:
@staticmethod
def auto_complete_config(auto_complete_model_config):
inputs = []
outputs = []
dims = [-1, -1]
optional = True
for data_type in ["type_int64"]:
type_name = data_type.split("_")[1].lower()
input_name = f"{type_name}_input"
output_name = f"{type_name}_output_total"
inputs.append(
{
"name": input_name,
"data_type": data_type,
"dims": dims,
"optional": optional,
}
)
outputs.append({"name": output_name, "data_type": data_type, "dims": dims})
outputs.append(
{"name": "output_parameters", "data_type": "TYPE_STRING", "dims": [1]}
)
for input_ in inputs:
auto_complete_model_config.add_input(input_)
for output in outputs:
auto_complete_model_config.add_output(output)
auto_complete_model_config.set_max_batch_size(0)
return auto_complete_model_config
def initialize(self, args):
self._model_config = json.loads(args["model_config"])
self._request_gpu_memory = False
if "parameters" in self._model_config:
parameters = self._model_config["parameters"]
if (
"request_gpu_memory" in parameters
and parameters["request_gpu_memory"]["string_value"] == "True"
):
self._request_gpu_memory = True
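    # execute reduces each row of the int64 input to its product and returns
    # the per-row totals, optionally in GPU memory via cupy and DLPack.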
def execute(self, requests):
responses = []
for request in requests:
output_tensors = []
for input_tensor in request.inputs():
input_value = input_tensor.as_numpy()
output_value = np.array([[x.prod() for x in input_value]])
if self._request_gpu_memory:
output_value = cupy.array(output_value)
output_tensor = pb_utils.Tensor.from_dlpack(
input_tensor.name().replace("input", "output_total"),
output_value,
)
else:
output_tensor = pb_utils.Tensor(
input_tensor.name().replace("input", "output_total"),
output_value,
)
output_tensors.append(output_tensor)
output_parameters = np.array([request.parameters()]).astype(np.object_)
output_tensors.append(
pb_utils.Tensor("output_parameters", output_parameters)
)
responses.append(
pb_utils.InferenceResponse(
output_tensors=output_tensors,
)
)
return responses