"...git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "99cc11e6212cf92be492c0ca49e760f7c9ca57d8"
Commit e6c12674 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

test: Add remaining worker tests (#28)

parent 5a6e57c8
......@@ -46,6 +46,7 @@ skip = ["build"]
[tool.pytest.ini_options]
minversion = "8.0"
tmp_path_retention_policy = "failed"
# NOTE
# We ignore model.py explicitly here to avoid mypy errors with duplicate modules
......
......@@ -15,6 +15,7 @@
import asyncio
import logging
import os
import subprocess
import time
......@@ -28,22 +29,46 @@ logger = logging.getLogger(LOGGER_NAME)
NATS_PORT = 4223
TEST_API_SERVER_MODEL_REPO_PATH = (
"/workspace/worker/python/tests/integration/api_server/models"
"/workspace/worker/tests/python/integration/api_server/models"
)
def pytest_addoption(parser):
    """Register the ``--basetemp-permissions`` command-line option.

    The value is an octal permission string applied to pytest's base
    temporary directory (see the ``log_dir`` fixture).
    """
    parser.addoption(
        "--basetemp-permissions",
        action="store",
        help=(
            "Permissions of the base temporary directory used by tmp_path, "
            "as octal value. Examples: 700 (default), 750, 770"
        ),
    )
@pytest.fixture(scope="session")
def log_dir(request, tmp_path_factory):
    """Session-scoped directory under which all test servers/workers log.

    If ``--basetemp-permissions`` was given (octal string, e.g. "770"),
    chmod both pytest's base temp directory and the created logs
    directory so the logs stay accessible, e.g. to other users in CI.
    """
    log_dir = tmp_path_factory.mktemp("logs")
    try:
        permissions = request.config.getoption("--basetemp-permissions")
    except ValueError:
        # Option not registered in this pytest invocation; keep defaults.
        permissions = None
    if permissions:
        # Use the public tmp_path_factory API rather than the private
        # request.config._tmp_path_factory attribute.
        basetemp = tmp_path_factory.getbasetemp()
        os.chmod(basetemp, int(permissions, 8))
        os.chmod(log_dir, int(permissions, 8))
    return log_dir
@pytest.fixture(scope="session")
def nats_server():
server = NatsServer()
def nats_server(log_dir):
server = NatsServer(log_dir=log_dir / "nats")
yield server
del server
@pytest.fixture(scope="session")
def api_server():
def api_server(log_dir):
command = ["tritonserver", "--model-store", str(TEST_API_SERVER_MODEL_REPO_PATH)]
with open("api_server.stdout.log", "wt") as output_:
with open("api_server.stderr.log", "wt") as output_err:
api_server_log_dir = log_dir / "api_server"
os.makedirs(api_server_log_dir, exist_ok=True)
with open(api_server_log_dir / "api_server.stdout.log", "wt") as output_:
with open(api_server_log_dir / "api_server.stderr.log", "wt") as output_err:
process = subprocess.Popen(
command, stdin=subprocess.DEVNULL, stdout=output_, stderr=output_err
)
......
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio
import gc
import json
import queue
import threading
import traceback
import uuid
import triton_python_backend_utils as pb_utils
import ucp
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.remote_operator import RemoteOperator
class TritonPythonModel:
    """Triton Python backend model that acts as an API server frontend
    for the T3 ICP.

    Each Triton request is forwarded over the NATS request plane / UCX
    data plane to a remote worker via a ``RemoteOperator``; the remote
    (streamed) responses are relayed back through the request's
    decoupled response sender.
    """

    @staticmethod
    def auto_complete_config(auto_complete_model_config):
        """Add any missing inputs/outputs and force a decoupled,
        non-batching model configuration."""
        inputs = [
            {"name": "query", "data_type": "TYPE_STRING", "dims": [1]},
            {
                "name": "request_output_len",
                "data_type": "TYPE_INT32",
                "dims": [1],
            },
        ]
        outputs = [{"name": "output", "data_type": "TYPE_STRING", "dims": [-1]}]

        # Store the model configuration as a dictionary.
        config = auto_complete_model_config.as_dict()
        input_names = [inp["name"] for inp in config["input"]]
        output_names = [out["name"] for out in config["output"]]

        # Add only missing inputs and outputs to the model configuration.
        for inp in inputs:
            if inp["name"] not in input_names:
                auto_complete_model_config.add_input(inp)
        for out in outputs:
            if out["name"] not in output_names:
                auto_complete_model_config.add_output(out)

        # We need to use decoupled transaction policy for saturating T3
        auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))
        # Disabling batching in Triton,
        auto_complete_model_config.set_max_batch_size(0)
        return auto_complete_model_config

    async def _connect(self):
        """Connect the NATS request plane and the UCX data plane."""
        ucp.reset()
        self._request_plane = NatsRequestPlane(self._request_plane_uri)
        self._data_plane = UcpDataPlane()
        self._data_plane.connect()
        await self._request_plane.connect()

    async def _disconnect(self, timeout):
        """Close both planes; ``timeout`` bounds the data-plane release wait."""
        self._data_plane.close(wait_for_release=timeout)
        await self._request_plane.close()

    async def _await_shutdown(self):
        """
        Primary coroutine running on the engine event loop. This coroutine is responsible for
        keeping the engine alive until a shutdown is requested.
        """
        # first await the shutdown signal
        while self._shutdown_event.is_set() is False:
            await asyncio.sleep(5)
        # Wait for the ongoing_requests
        while self._ongoing_request_count > 0:
            self.logger.log_info(
                "[API Server] Awaiting remaining {} requests".format(
                    self._ongoing_request_count
                )
            )
            await asyncio.sleep(5)
        # Cancel everything else still scheduled on our loop.
        for task in asyncio.all_tasks(loop=self._loop):
            if task is not asyncio.current_task():
                task.cancel()
        self.logger.log_info("[API Server] Shutdown complete")

    def _create_task(self, coro):
        """
        Creates a task on the event loop which is running on a separate thread.
        """
        assert (
            self._shutdown_event.is_set() is False
        ), "Cannot create tasks after shutdown has been requested"
        return asyncio.run_coroutine_threadsafe(coro, self._loop)

    def _event_loop(self, loop):
        """
        Runs the engine's event loop on a separate thread.
        """
        asyncio.set_event_loop(loop)
        self._loop.run_until_complete(self._await_shutdown())

    def initialize(self, args):
        """Triton lifecycle hook: connect to the ICP, start the event-loop
        and response threads, and resolve output dtypes from the config."""
        model_config = json.loads(args["model_config"])
        self.logger = pb_utils.Logger

        # Starting asyncio event loop to process the received requests asynchronously.
        self._loop = asyncio.get_event_loop()
        self._event_thread = threading.Thread(
            target=self._event_loop, args=(self._loop,)
        )
        self._shutdown_event = asyncio.Event()
        self._event_thread.start()

        self._request_plane_uri = model_config["parameters"]["request_plane_uri"][
            "string_value"
        ]
        future = self._create_task(self._connect())
        try:
            _ = future.result(timeout=5)
        except TimeoutError:
            self.logger.log_error(
                "The connection to T3 ICP took too long, cancelling the task..."
            )
            future.cancel()
        except Exception as exc:
            self.logger.log_error(
                f"The connection to T3 ICP raised an exception: {exc!r}"
            )

        self._remote_worker_name = model_config["parameters"]["remote_worker_name"][
            "string_value"
        ]
        self._remote_operator = RemoteOperator(
            self._remote_worker_name, 1, self._request_plane, self._data_plane
        )

        # Starting the response thread. It allows API Server to keep making progress while
        # response sender(s) are sending responses to server frontend.
        self._response_queue = queue.Queue()
        self._response_thread = threading.Thread(target=self.response_loop)
        self._response_thread.start()

        # Counter to keep track of ongoing request counts
        self._ongoing_request_count = 0

        # Cache numpy dtypes for each configured output (e.g. self.output_dtype).
        for output_name in ["output"]:
            setattr(
                self,
                output_name.lower() + "_dtype",
                pb_utils.triton_string_to_numpy(
                    pb_utils.get_output_config_by_name(model_config, output_name)[
                        "data_type"
                    ]
                ),
            )

    def response_loop(self):
        """Drain the response queue and forward items to their response
        senders; runs on a dedicated thread until a None sentinel arrives."""
        while True:
            item = self._response_queue.get()
            # To signal shutdown a None item will be added to the queue.
            if item is None:
                break
            response_sender, response, response_flag = item
            del item
            try:
                response_sender.send(response, response_flag)
            except Exception as e:
                self.logger.log_error(
                    f"An error occurred while sending a response: {e}"
                )
            finally:
                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
                    self._ongoing_request_count -= 1
                    del response_sender
                    # Encourage prompt release of sender resources when idle.
                    if self._ongoing_request_count == 0:
                        gc.collect()

    def execute(self, requests):
        """Triton lifecycle hook: schedule each request for remote
        execution on the event loop; responses are sent asynchronously."""
        for request in requests:
            if request is not None:
                self._create_task(self.remote_execute(request))
        return None

    async def remote_execute(self, request):
        """Forward one Triton request to the remote operator and stream
        its non-final responses back; enqueue the FINAL flag at the end."""
        response_sender = request.get_response_sender()
        self._ongoing_request_count += 1
        query = pb_utils.get_input_tensor_by_name(request, "query").as_numpy()
        request_output_len = pb_utils.get_input_tensor_by_name(
            request, "request_output_len"
        ).as_numpy()
        request_id = str(uuid.uuid4())
        infer_request = self._remote_operator.create_request(
            inputs={"query": query, "request_output_len": request_output_len},
            request_id=request_id,
        )
        try:
            async for response in await self._remote_operator.async_infer(
                inference_request=infer_request
            ):
                if response.error:
                    raise pb_utils.TritonModelException(response.error.message())
                if not response.final:
                    output = response.outputs["output"]
                    output_value = output.to_bytes_array()
                    # Just forwarding query to the pre-processed input_ids
                    output_tensor = pb_utils.Tensor(
                        "output", output_value.astype(self.output_dtype)
                    )
                    inference_response = pb_utils.InferenceResponse(
                        output_tensors=[output_tensor]
                    )
                    self._response_queue.put_nowait(
                        (response_sender, inference_response, 0)
                    )
        except Exception as e:
            # format_exc() (not print_exc(), which returns None) so the
            # traceback actually appears in the log message.
            self.logger.log_error(
                f"Failed running remote inference {traceback.format_exc()}"
            )
            raise pb_utils.TritonModelException(repr(e))
        self._response_queue.put_nowait(
            (response_sender, None, pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
        )
        return None

    def finalize(self):
        """Triton lifecycle hook: disconnect from the ICP and join the
        event-loop and response threads."""
        self.logger.log_info("[API Server] Issuing finalize to API Server")
        future = self._create_task(self._disconnect(timeout=5))
        try:
            _ = future.result(timeout=7)
        except TimeoutError:
            self.logger.log_error(
                "The connection to T3 ICP took too long, cancelling the task..."
            )
            future.cancel()
        except Exception as exc:
            self.logger.log_error(
                f"The connection to T3 ICP raised an exception: {exc!r}"
            )
        self._shutdown_event.set()

        # Shutdown the event thread.
        if self._event_thread is not None:
            self._event_thread.join()
            self._event_thread = None

        # Shutdown the response thread.
        self._response_queue.put(None)
        if self._response_thread is not None:
            self._response_thread.join()
            self._response_thread = None
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
name: "mock_disaggregated_serving"
backend: "python"
max_batch_size: 0
model_transaction_policy {
decoupled: true
}
parameters: {
key: "remote_worker_name"
value: {
string_value: "mock_disaggregated_serving"
}
}
parameters: {
key: "request_plane_uri"
value: {
string_value: "nats://localhost:4223"
}
}
# Add more parameters as per requirement
instance_group [
{
count: 1
kind : KIND_CPU
}
]
......@@ -15,10 +15,8 @@
import asyncio
import logging
import shutil
import sys
from multiprocessing import Process
from pathlib import Path
import cupy
import numpy
......@@ -61,7 +59,7 @@ pytestmark = pytest.mark.pre_merge
@pytest.fixture
def workers(request):
def workers(request, log_dir):
operator_configs = {}
store_outputs_in_response = request.getfixturevalue("store_outputs_in_response")
......@@ -90,15 +88,13 @@ def workers(request):
worker_configs = []
test_log_directory = Path(__file__).with_suffix("")
if test_log_directory.exists():
shutil.rmtree(test_log_directory, ignore_errors=True)
test_log_directory.mkdir()
test_log_dir = log_dir / request.node.name
test_log_dir.mkdir(parents=True, exist_ok=True)
# We will instantiate a worker for each operator
for name, operator_config in operator_configs.items():
# Set the logging directory
log_dir = test_log_directory / name
worker_log_dir = test_log_dir / name
worker_configs.append(
WorkerConfig(
request_plane=NatsRequestPlane,
......@@ -108,8 +104,8 @@ def workers(request):
{"request_plane_uri": f"nats://localhost:{NATS_PORT}"},
),
log_level=TRITON_LOG_LEVEL,
log_dir=str(log_dir),
triton_log_path=str(log_dir / TRITON_LOG_FILE),
log_dir=str(worker_log_dir),
triton_log_path=str(worker_log_dir / TRITON_LOG_FILE),
operators=[operator_config],
)
)
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import sys
import uuid
from multiprocessing import Process
import cupy
import numpy
import pytest
import ucp
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.log_formatter import LOGGER_NAME
from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.worker import WorkerConfig
NATS_PORT = 4223
MODEL_REPOSITORY = (
"/workspace/worker/tests/python/integration/operators/triton_core_models"
)
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
TRITON_LOG_FILE = "triton.log"
TRITON_LOG_LEVEL = 6
logger = logging.getLogger(LOGGER_NAME)
# Run cupy's cuda.is_available once to
# avoid the exception hitting runtime code.
try:
if cupy.cuda.is_available():
pass
else:
print("CUDA not available.")
except CUDARuntimeError:
print("CUDA not available")
# TODO
# Decide if this should be
# pre merge, nightly, or weekly
pytestmark = pytest.mark.pre_merge
@pytest.fixture
def workers(request, log_dir, number_workers=10):
    """Deploy ``number_workers`` identity-operator workers for the test,
    each logging into its own per-test subdirectory."""
    identity_config = OperatorConfig(
        name="identity",
        implementation="identity:Identity",
        version=1,
        max_inflight_requests=10,
        repository=OPERATORS_REPOSITORY,
    )

    # Per-test log directory; one subdirectory per worker instance.
    test_log_dir = log_dir / request.node.name
    test_log_dir.mkdir(parents=True, exist_ok=True)

    def _config_for(index):
        worker_log_dir = test_log_dir / f"identity_{index}"
        return WorkerConfig(
            request_plane=NatsRequestPlane,
            data_plane=UcpDataPlane,
            request_plane_args=(
                [],
                {"request_plane_uri": f"nats://localhost:{NATS_PORT}"},
            ),
            log_level=TRITON_LOG_LEVEL,
            log_dir=str(worker_log_dir),
            triton_log_path=str(worker_log_dir / TRITON_LOG_FILE),
            operators=[identity_config],
        )

    deployment = Deployment([_config_for(i) for i in range(number_workers)])
    deployment.start()
    yield deployment
    deployment.shutdown()
async def post_requests(num_requests, num_targets):
    """Post requests until ``num_targets`` distinct workers have responded,
    then direct all remaining requests only at those workers.

    Asserts at the end that the set of components chosen as targets equals
    the set of components that actually responded — i.e. direct addressing
    via ``component_id`` was honored.
    """
    ucp.reset()
    data_plane = UcpDataPlane()
    data_plane.connect()
    request_plane = NatsRequestPlane(f"nats://localhost:{NATS_PORT}")
    await request_plane.connect()
    identity_operator = RemoteOperator("identity", 1, request_plane, data_plane)
    target_components = set()
    target_component_list: list[uuid.UUID] = []
    responding_components = set()
    for index in range(num_requests):
        request = identity_operator.create_request(
            inputs={"input": [index]},
        )
        target_component = None
        if target_component_list:
            # we have the list of targets
            # only send to workers in that list (round-robin)
            target_index = index % len(target_component_list)
            target_component = target_component_list[target_index]
            identity_operator.component_id = target_component
        async for response in await identity_operator.async_infer(request):
            responding_component = response.component_id
            # Identity operator must echo the input back unchanged.
            numpy.testing.assert_equal(
                numpy.from_dlpack(response.outputs["output"]), request.inputs["input"]
            )
            responding_components.add(responding_component)
            if not target_component_list:
                # add to list of acceptable targets
                target_components.add(responding_component)
                if len(target_components) >= num_targets:
                    # finalize list
                    target_component_list = list(target_components)
    timeout = 5
    data_plane.close(timeout)
    await request_plane.close()
    assert target_components == responding_components
def run(num_requests, num_targets=5):
    """Child-process entry point: exit with the result of post_requests
    (None on success, i.e. exit code 0)."""
    outcome = asyncio.run(
        post_requests(num_requests=num_requests, num_targets=num_targets)
    )
    sys.exit(outcome)
@pytest.mark.skipif(
    "(not os.path.exists('/usr/local/bin/nats-server'))",
    reason="NATS.io not present",
)
@pytest.mark.timeout(30)
def test_direct(request, nats_server, workers):
    """Verify requests can be addressed directly to specific worker
    components; fails if the child process exits non-zero."""
    # Using a separate process to use data plane across multiple tests.
    p = Process(target=run, args=(50,))
    p.start()
    p.join()
    assert p.exitcode == 0
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import queue
import sys
import time
from functools import partial
from multiprocessing import Process
import cupy
import numpy
import pytest
import tritonclient.grpc as grpcclient
import ucp
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from transformers import XLNetTokenizer
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.log_formatter import LOGGER_NAME
from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.triton_core_operator import TritonCoreOperator
from triton_distributed.worker.worker import WorkerConfig
from tritonclient.utils import InferenceServerException
from tritonserver import Tensor
NATS_PORT = 4223
MODEL_REPOSITORY = (
"/workspace/worker/tests/python/integration/operators/triton_core_models"
)
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
TRITON_LOG_FILE = "triton.log"
TRITON_LOG_LEVEL = 6
logger = logging.getLogger(LOGGER_NAME)
# Run cupy's cuda.is_available once to
# avoid the exception hitting runtime code.
try:
if cupy.cuda.is_available():
pass
else:
print("CUDA not available.")
except CUDARuntimeError:
print("CUDA not available")
# TODO
# Decide if this should be
# pre merge, nightly, or weekly
pytestmark = pytest.mark.pre_merge
@pytest.fixture
def workers(request, log_dir):
    """Deploy one worker per operator of the mock disaggregated-serving
    pipeline (four Triton-core stages plus the orchestrating operator)."""
    operator_configs = {}

    # Triton-core-backed pipeline stages.
    for stage in ["preprocessing", "context", "generation", "postprocessing"]:
        operator_configs[stage] = OperatorConfig(
            name=stage,
            implementation=TritonCoreOperator,
            version=1,
            max_inflight_requests=10,
            repository=MODEL_REPOSITORY,
        )

    # Custom python operator that drives the pipeline end to end.
    operator_configs["mock_disaggregated_serving"] = OperatorConfig(
        name="mock_disaggregated_serving",
        implementation="mock_disaggregated_serving:MockDisaggregatedServing",
        version=1,
        max_inflight_requests=10,
        repository=OPERATORS_REPOSITORY,
    )

    test_log_dir = log_dir / request.node.name
    test_log_dir.mkdir(parents=True, exist_ok=True)

    # One worker (and one log subdirectory) per operator.
    worker_configs = [
        WorkerConfig(
            request_plane=NatsRequestPlane,
            data_plane=UcpDataPlane,
            request_plane_args=(
                [],
                {"request_plane_uri": f"nats://localhost:{NATS_PORT}"},
            ),
            log_level=TRITON_LOG_LEVEL,
            log_dir=str(test_log_dir / name),
            triton_log_path=str(test_log_dir / name / TRITON_LOG_FILE),
            operators=[config],
        )
        for name, config in operator_configs.items()
    ]

    deployment = Deployment(worker_configs)
    deployment.start()
    yield deployment
    deployment.shutdown()
def _create_inputs(number):
    """Build ``number`` (input, expected-output) pairs for the
    mock_disaggregated_serving operator.

    Each input is a dict of numpy arrays ("query", "request_output_len");
    each expected output wraps the prompt's token strings in a
    tritonserver Tensor.
    """
    inputs = []
    outputs = []
    # The tokenizer is loop-invariant; load it once instead of per request.
    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    for _ in range(number):
        request_output_len = 10
        query_arr = numpy.array(["This is a sample prompt"], dtype=numpy.object_)
        request_output_len_arr = numpy.array([request_output_len], dtype=numpy.int32)
        input_ = {"query": query_arr, "request_output_len": request_output_len_arr}
        tokens = numpy.array(tokenizer.encode(query_arr[0]))
        # Expected output is the prompt re-expressed as its token strings.
        expected_output = numpy.array(
            tokenizer.convert_ids_to_tokens((tokens.tolist()))
        )
        output_data_ = {"output": Tensor._from_object(expected_output)}
        inputs.append(input_)
        outputs.append(output_data_)
    return inputs, outputs
async def post_requests(num_requests):
    """Send ``num_requests`` streaming requests to the
    mock_disaggregated_serving operator and validate every non-final
    response against the expected tokenized output.

    On any failure the planes are closed before the exception is re-raised
    so the child process does not hang on open connections.
    """
    ucp.reset()
    data_plane = UcpDataPlane()
    data_plane.connect()
    request_plane = NatsRequestPlane(f"nats://localhost:{NATS_PORT}")
    await request_plane.connect()
    mock_disaggregated_serving_operator = RemoteOperator(
        "mock_disaggregated_serving", 1, request_plane, data_plane
    )
    expected_results = {}
    inputs, outputs = _create_inputs(num_requests)
    begin = None
    # Wall-clock time from request send to each streamed response.
    token_latency = []
    # NOTE(review): passed to close(wait_for_release=...); True here, while
    # other tests pass a numeric timeout — confirm the intended semantics.
    timeout = True
    for i, input_dict in enumerate(inputs):
        request_id = str(i)
        request = mock_disaggregated_serving_operator.create_request(
            inputs=input_dict, request_id=request_id
        )
        begin = time.time()
        response_count = 0
        try:
            async for response in await mock_disaggregated_serving_operator.async_infer(
                inference_request=request
            ):
                token_latency.append(time.time() - begin)
                expected_results[request_id] = outputs[i]
                if not response.final:
                    for output_name, expected_value in expected_results[
                        response.request_id
                    ].items():
                        output = response.outputs[output_name]
                        output_value = output.to_bytes_array()
                        print(f"Final Output: {output_value}")
                        numpy.testing.assert_equal(
                            output_value, expected_value.to_bytes_array()
                        )
                    response_count += 1
            # 1 response from context and 10 responses from generation
            assert response_count == 11
        except Exception as e:
            print("Failed collecting responses:" + repr(e))
            del response
            print(f"Token latency: {token_latency}")
            data_plane.close(wait_for_release=timeout)
            await request_plane.close()
            raise e
    print(f"Token latency: {token_latency}")
    data_plane.close(wait_for_release=timeout)
    await request_plane.close()
def run(num_requests):
    """Child-process entry point; exit code reflects post_requests' result."""
    result = asyncio.run(post_requests(num_requests=num_requests))
    sys.exit(result)
@pytest.mark.skipif(
    "(not os.path.exists('/usr/local/bin/nats-server'))",
    reason="NATS.io not present or test is not configured to run with mock disaggregated serving",
)
def test_mock_disaggregated_serving(request, nats_server, workers):
    """Run the full mock disaggregated-serving pipeline once via the ICP
    and assert the child process exited cleanly."""
    # Using a separate process to use data plane across multiple tests.
    p = Process(target=run, args=(1,))
    p.start()
    p.join()
    assert p.exitcode == 0
class UserData:
    """Collects streamed gRPC results (or errors) delivered by the
    stream callback, in arrival order."""

    def __init__(self):
        # Each item is a grpcclient.Result or an InferenceServerException.
        self._completed_requests = queue.Queue()
# Define the callback function. Note the last two parameters should be
# result and error. InferenceServerClient would provide the results of an
# inference as grpcclient.InferResult in result. For a successful
# inference, error will be None; otherwise it will be an object of
# tritonclientutils.InferenceServerException holding the error details.
def callback(user_data, result, error):
    """Stream callback: enqueue the error when present, otherwise the result."""
    if error:
        user_data._completed_requests.put(error)
        return
    user_data._completed_requests.put(result)
async def send_kserve_requests(num_requests):
    """Stream ``num_requests`` inference requests through the Triton
    gRPC (KServe) frontend of the api_server and drain the responses.

    Results arrive asynchronously via ``callback`` into ``user_data``;
    an InferenceServerException item is re-raised.
    """
    inputs_dict, outputs_dicts = _create_inputs(num_requests)
    inputs = []
    inputs.append(grpcclient.InferInput("query", [1], "BYTES"))
    inputs.append(grpcclient.InferInput("request_output_len", [1], "INT32"))
    user_data = UserData()
    with grpcclient.InferenceServerClient("localhost:8001") as client:
        client.start_stream(
            callback=partial(callback, user_data),
        )
        for i, input_dict in enumerate(inputs_dict):
            inputs[0].set_data_from_numpy(input_dict["query"])
            inputs[1].set_data_from_numpy(input_dict["request_output_len"])
            client.async_stream_infer(
                model_name="mock_disaggregated_serving", inputs=inputs
            )
        recv_count = 0
        # NOTE(review): hard-coded 10 expected responses regardless of
        # num_requests — confirm this matches the generation stream length.
        while recv_count < 10:
            data_item = user_data._completed_requests.get()
            recv_count += 1
            if isinstance(data_item, InferenceServerException):
                raise data_item
            else:
                result = data_item.as_numpy("output")
                print("test \n")
                print(result)
    # Wait for the tensor clean-up
    time.sleep(5)
def run_kserve(num_requests):
    """Child-process entry point for the KServe (gRPC frontend) variant."""
    result = asyncio.run(send_kserve_requests(num_requests=num_requests))
    sys.exit(result)
@pytest.mark.skipif(
    "(not os.path.exists('/usr/local/bin/nats-server'))",
    reason="NATS.io not present",
)
def test_mock_disaggregated_serving_kserve(request, nats_server, workers, api_server):
    """Exercise the pipeline through the Triton gRPC (KServe) frontend
    and assert the child process exited cleanly."""
    # Using a separate process to use data plane across multiple tests.
    p = Process(target=run_kserve, args=(1,))
    p.start()
    p.join()
    assert p.exitcode == 0
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio
import logging
import numpy
import pytest
import ucp
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.log_formatter import LOGGER_NAME
from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.worker import WorkerConfig
NATS_PORT = 4223
MODEL_REPOSITORY = (
"/workspace/worker/tests/python/integration/operators/triton_core_models"
)
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
TRITON_LOG_FILE = "triton.log"
TRITON_LOG_LEVEL = 6
logger = logging.getLogger(LOGGER_NAME)
# TODO
# Decide if this should be
# pre merge, nightly, or weekly
pytestmark = pytest.mark.pre_merge
@pytest.fixture
def workers(log_dir, request, number_workers=1):
    """Deploy identity workers for the benchmark, forwarding the
    parametrized store_outputs_in_response flag to the operator."""
    store_outputs_in_response = request.getfixturevalue("store_outputs_in_response")

    identity_config = OperatorConfig(
        name="identity",
        implementation="identity:Identity",
        version=1,
        max_inflight_requests=10,
        parameters={"store_outputs_in_response": store_outputs_in_response},
        repository=OPERATORS_REPOSITORY,
    )

    # Per-test log directory, one subdirectory per worker instance.
    test_log_dir = log_dir / request.node.name
    test_log_dir.mkdir(parents=True, exist_ok=True)

    worker_configs = []
    for index in range(number_workers):
        worker_log_dir = test_log_dir / f"identity_{index}"
        worker_configs.append(
            WorkerConfig(
                request_plane=NatsRequestPlane,
                data_plane=UcpDataPlane,
                request_plane_args=(
                    [],
                    {"request_plane_uri": f"nats://localhost:{NATS_PORT}"},
                ),
                log_level=TRITON_LOG_LEVEL,
                log_dir=str(worker_log_dir),
                triton_log_path=str(worker_log_dir / TRITON_LOG_FILE),
                operators=[identity_config],
            )
        )

    deployment = Deployment(worker_configs)
    deployment.start()
    yield deployment
    deployment.shutdown()
def _create_inputs(number, tensor_size_in_kb):
inputs = []
outputs = []
elem_cnt = int(tensor_size_in_kb * 1024 / 4)
for _ in range(number):
input_ = numpy.random.randint(low=1, high=100, size=[elem_cnt])
expected_ = {}
expected_["output"] = input_
inputs.append(input_)
outputs.append(expected_)
return inputs, outputs
def run(
    aio_benchmark,
    store_inputs_in_request,
    store_outputs_in_response,
    tensor_size_in_kb,
    data_plane_tracker,
):
    """Run one benchmark round against the identity operator.

    The UCX data plane must exist exactly once for the whole module, so
    it is created on the tracker's first run and closed on its last; the
    NATS request plane is opened and closed per round.
    """
    if data_plane_tracker.is_first_run:
        ucp.reset()
        data_plane_tracker._data_plane = UcpDataPlane()
        data_plane_tracker._data_plane.connect()
    request_plane = NatsRequestPlane(f"nats://localhost:{NATS_PORT}")
    asyncio.get_event_loop().run_until_complete(request_plane.connect())
    identity_operator = RemoteOperator(
        "identity", 1, request_plane, data_plane_tracker._data_plane
    )
    inputs, outputs = _create_inputs(1, tensor_size_in_kb)
    # aio_benchmark repeatedly invokes post_requests with these arguments.
    aio_benchmark(
        post_requests,
        identity_operator,
        inputs,
        outputs,
        store_inputs_in_request,
        store_outputs_in_response,
    )
    timeout = 5
    asyncio.get_event_loop().run_until_complete(request_plane.close())
    if data_plane_tracker.is_last_run:
        data_plane_tracker._data_plane.close(timeout)
async def post_requests(
    identity_model, inputs, outputs, store_inputs_in_request, store_outputs_in_response
):
    """Issue one request per input to the identity operator and pull every
    output tensor back to host memory as the responses complete.

    NOTE(review): store_outputs_in_response is unused here — it only
    affects the operator configuration in the workers fixture.
    """
    results = []
    expected_results = {}
    for i, input_ in enumerate(inputs):
        request_id = str(i)
        request = identity_model.create_request(
            inputs={"input": input_}, request_id=request_id
        )
        if store_inputs_in_request:
            # Embed the tensor in the request itself instead of the data plane.
            request.store_inputs_in_request.add("input")
        results.append(identity_model.async_infer(request))
        expected_results[request_id] = outputs[i]
    for result in asyncio.as_completed(results):
        responses = await result
        async for response in responses:
            for output_name, expected_value in expected_results[
                response.request_id
            ].items():
                output = response.outputs[output_name]
                # Pull the tensor to host to include transfer cost in the
                # measured round; value is not asserted in this benchmark.
                _ = numpy.from_dlpack(output.to_host())
                del output
            del response
@pytest.fixture(scope="module")
def data_plane_tracker():
    """Module-scoped run tracker so the first benchmark round opens the
    shared UCX data plane and the last round closes it (see run())."""

    class Tracker:
        def __init__(self):
            # total_runs is set by the test once the parameter count is known.
            self.total_runs = 0
            self.current_run = 0
            self._data_plane = None

        def increment_run(self):
            self.current_run += 1

        @property
        def is_first_run(self):
            return self.current_run == 1

        @property
        def is_last_run(self):
            return self.current_run == self.total_runs

    return Tracker()
# FIXME: NATS default size limit is 1 MB. However, even when the tensor_size_in_kb
# is set as 600, which corresponds to 0.6144 MB, we are hitting MaxPayloadError.
# Need to investigate why the limit is being hit.
@pytest.mark.skipif(
    "(not os.path.exists('/usr/local/bin/nats-server'))",
    reason="NATS.io not present or test is configured to run with mock disaggregated_serving",
)
@pytest.mark.parametrize(
    ["store_inputs_in_request", "store_outputs_in_response"],
    [(True, True), (False, False)],
)
@pytest.mark.parametrize(
    "tensor_size_in_kb",
    [10, 100, 500],
)
@pytest.mark.benchmark(min_rounds=50, max_time=0.5)
def test_identity(
    request,
    nats_server,
    workers,
    aio_benchmark,
    store_inputs_in_request,
    store_outputs_in_response,
    tensor_size_in_kb,
    data_plane_tracker,
):
    """
    This benchmark test checks the latency of a simple operator which returns its input
    as its output without any processing.

    NOTE: We cannot use the benchmark fixture in a child process, so the data plane
    object must be opened in the same process as pytest itself.

    This means the pytest main process cannot create another data plane object in any
    other test. Hence we use a run tracker to open the data plane on the first
    parametrized run and close it on the last.
    """
    if data_plane_tracker.total_runs == 0:
        # 2 storage modes x 3 tensor sizes = 6 parametrized runs.
        data_plane_tracker.total_runs = 6  # Set this to the number of parameters
    data_plane_tracker.increment_run()
    run(
        aio_benchmark,
        store_inputs_in_request,
        store_outputs_in_response,
        tensor_size_in_kb,
        data_plane_tracker,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment