"pcdet/ops/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "19068b52a182c04694305c4542fede9dddc4d527"
Commit 53712d62 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

test: Simplify e2e integration test for worker (#24)


Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent ab274e74
...@@ -36,6 +36,9 @@ CMakeCache.txt ...@@ -36,6 +36,9 @@ CMakeCache.txt
*_pb2.pyi *_pb2.pyi
*.svg *.svg
*pytest_report.md
*pytest_report.xml
**/__pycache__ **/__pycache__
**/venv **/venv
*.cache *.cache
......
...@@ -36,6 +36,7 @@ class Deployment: ...@@ -36,6 +36,7 @@ class Deployment:
args=[worker_config], args=[worker_config],
) )
) )
self._workers[-1].start()
def shutdown(self, join=True, timeout=10): def shutdown(self, join=True, timeout=10):
for worker in self._workers: for worker in self._workers:
......
...@@ -15,17 +15,13 @@ ...@@ -15,17 +15,13 @@
import asyncio import asyncio
import logging import logging
import multiprocessing
import signal
import subprocess import subprocess
import sys
import time import time
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from triton_distributed.icp.nats_request_plane import NatsServer from triton_distributed.icp.nats_request_plane import NatsServer
from triton_distributed.worker.log_formatter import LOGGER_NAME, setup_logger from triton_distributed.worker.log_formatter import LOGGER_NAME
from triton_distributed.worker.worker import Worker
logger = logging.getLogger(LOGGER_NAME) logger = logging.getLogger(LOGGER_NAME)
...@@ -36,95 +32,6 @@ TEST_API_SERVER_MODEL_REPO_PATH = ( ...@@ -36,95 +32,6 @@ TEST_API_SERVER_MODEL_REPO_PATH = (
) )
async def _wait_for_tasks(loop):
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
try:
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
print("Encountered an error in task clean-up: %s", e)
print("Stopping the event loop")
loop.stop()
def _run_worker(name, queue, worker_config):
tensor_store_keys = None
try:
with open(f"{name}.worker.stdout.log", "w") as output_:
with open(f"{name}.worker.stderr.log", "w") as output_err:
with open(f"{name}.worker.triton.log", "w"):
sys.stdout = output_
sys.stderr = output_err
triton_log_filename = f"{name}.worker.triton.log"
setup_logger(log_level=worker_config.log_level)
worker_config.triton_log_file = triton_log_filename
worker_config.name = name
try:
worker = Worker(worker_config)
except Exception as e:
queue.put(f"Failed to start {name}: {e}")
logger.exception("Failed to instantiate a worker class")
loop = asyncio.new_event_loop()
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for sig in signals:
loop.add_signal_handler(
sig, lambda s=sig: asyncio.create_task(worker.shutdown(s)) # type: ignore
)
try:
queue.put("READY")
loop.run_until_complete(worker.serve())
except asyncio.CancelledError:
print("server cancellation detected")
finally:
loop.run_until_complete(_wait_for_tasks(loop))
loop.close()
tensor_store_keys = list(
worker._data_plane._tensor_store.keys()
)
sys.exit(len(tensor_store_keys))
except Exception as e:
print(f"Worker Serving Failed to start: {e}")
queue.put(f"Failed to start {name}: {e}")
raise e
class WorkerManager:
ctx = multiprocessing.get_context("spawn")
@staticmethod
def setup_worker_process(operators, name, queue, worker_config):
worker_config.name = name
worker_config.operators = operators
process = WorkerManager.ctx.Process(
target=_run_worker,
args=(name, queue, worker_config),
name=name,
)
process.start()
return process
@staticmethod
def cleanup_workers(workers, check_status=True):
for worker in workers:
print(f"Terminating {worker.name} worker", flush=True)
worker.terminate()
for worker in workers:
worker.join()
print(f"{worker.name} exited with {worker.exitcode} stored tensors")
assert (
worker.exitcode == 0 if check_status else True
), f"{worker.name} exited with {worker.exitcode} stored tensors"
@pytest.fixture
def worker_manager():
return WorkerManager
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def nats_server(): def nats_server():
server = NatsServer() server = NatsServer()
......
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
import asyncio import asyncio
import logging import logging
import shutil
import sys import sys
from multiprocessing import Manager, Process from multiprocessing import Process
from pathlib import Path
import cupy import cupy
import numpy import numpy
...@@ -25,6 +27,7 @@ import ucp ...@@ -25,6 +27,7 @@ import ucp
from cupy_backends.cuda.api.runtime import CUDARuntimeError from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.nats_request_plane import NatsRequestPlane from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.log_formatter import LOGGER_NAME from triton_distributed.worker.log_formatter import LOGGER_NAME
from triton_distributed.worker.operator import OperatorConfig from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator from triton_distributed.worker.remote_operator import RemoteOperator
...@@ -32,8 +35,11 @@ from triton_distributed.worker.triton_core_operator import TritonCoreOperator ...@@ -32,8 +35,11 @@ from triton_distributed.worker.triton_core_operator import TritonCoreOperator
from triton_distributed.worker.worker import WorkerConfig from triton_distributed.worker.worker import WorkerConfig
NATS_PORT = 4223 NATS_PORT = 4223
MODEL_REPOSITORY = "/workspace/worker/tests/python/integration/operators/models" MODEL_REPOSITORY = (
WORKFLOW_REPOSITORY = "/workspace/worker/tests/python/integration/operators" "/workspace/worker/tests/python/integration/operators/triton_core_models"
)
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
TRITON_LOG_FILE = "triton.log"
TRITON_LOG_LEVEL = 6 TRITON_LOG_LEVEL = 6
logger = logging.getLogger(LOGGER_NAME) logger = logging.getLogger(LOGGER_NAME)
...@@ -55,93 +61,64 @@ pytestmark = pytest.mark.pre_merge ...@@ -55,93 +61,64 @@ pytestmark = pytest.mark.pre_merge
@pytest.fixture @pytest.fixture
def workers(worker_manager, request): def workers(request):
worker_config = WorkerConfig( operator_configs = {}
request_plane=NatsRequestPlane,
data_plane=UcpDataPlane,
request_plane_args=([], {"request_plane_uri": f"nats://localhost:{NATS_PORT}"}),
log_level=TRITON_LOG_LEVEL,
)
store_outputs_in_response = request.getfixturevalue("store_outputs_in_response")
add_model = OperatorConfig( store_outputs_in_response = request.getfixturevalue("store_outputs_in_response")
name="add", # Add configs for triton core operators
implementation=TritonCoreOperator, triton_core_operators = ["add", "multiply", "divide"]
version=1, for operator_name in triton_core_operators:
max_inflight_requests=10, operator_configs[operator_name] = OperatorConfig(
parameters={"store_outputs_in_response": store_outputs_in_response}, name=operator_name,
repository=MODEL_REPOSITORY,
)
multiply_model = OperatorConfig(
name="multiply",
implementation=TritonCoreOperator,
version=1,
max_inflight_requests=10,
parameters={"store_outputs_in_response": store_outputs_in_response},
repository=MODEL_REPOSITORY,
)
divide_model = OperatorConfig(
name="divide",
implementation=TritonCoreOperator, implementation=TritonCoreOperator,
version=1, version=1,
max_inflight_requests=10, max_inflight_requests=10,
parameters={"store_outputs_in_response": store_outputs_in_response}, parameters={"store_outputs_in_response": store_outputs_in_response},
repository=MODEL_REPOSITORY, repository=MODEL_REPOSITORY,
) )
workflow = OperatorConfig(
name="add_multiply_divide", # Add configs for other custom operators
operator_name = "add_multiply_divide"
operator_configs[operator_name] = OperatorConfig(
name=operator_name,
implementation="add_multiply_divide:AddMultiplyDivide", implementation="add_multiply_divide:AddMultiplyDivide",
version=1, version=1,
max_inflight_requests=10, max_inflight_requests=10,
parameters={"store_outputs_in_response": store_outputs_in_response}, parameters={"store_outputs_in_response": store_outputs_in_response},
repository=WORKFLOW_REPOSITORY, repository=OPERATORS_REPOSITORY,
) )
with Manager() as manager: worker_configs = []
workers = []
queues = []
queues.append(manager.Queue(maxsize=1)) test_log_directory = Path(__file__).with_suffix("")
workers.append( if test_log_directory.exists():
worker_manager.setup_worker_process( shutil.rmtree(test_log_directory, ignore_errors=True)
[add_model], "add", queues[-1], worker_config test_log_directory.mkdir()
)
)
queues.append(manager.Queue(maxsize=1)) # We will instantiate a worker for each operator
workers.append( for name, operator_config in operator_configs.items():
worker_manager.setup_worker_process( # Set the logging directory
[multiply_model], "multiply", queues[-1], worker_config log_dir = test_log_directory / name
) worker_configs.append(
) WorkerConfig(
request_plane=NatsRequestPlane,
queues.append(manager.Queue(maxsize=1)) data_plane=UcpDataPlane,
workers.append( request_plane_args=(
worker_manager.setup_worker_process( [],
[divide_model], "divide", queues[-1], worker_config {"request_plane_uri": f"nats://localhost:{NATS_PORT}"},
) ),
) log_level=TRITON_LOG_LEVEL,
log_dir=str(log_dir),
queues.append(manager.Queue(maxsize=1)) triton_log_path=str(log_dir / TRITON_LOG_FILE),
workers.append( operators=[operator_config],
worker_manager.setup_worker_process(
[workflow], "add_multiply_divide", queues[-1], worker_config
) )
) )
workers_failed = False worker_deployment = Deployment(worker_configs)
status_list = []
for queue, worker in zip(queues, workers):
status = queue.get()
status_list.append(status)
if status != "READY":
workers_failed = True
if workers_failed: worker_deployment.start()
worker_manager.cleanup_workers(workers, check_status=False) yield worker_deployment
raise Exception(f"Failed to start worker processes: {status_list}") worker_deployment.shutdown()
yield workers
worker_manager.cleanup_workers(workers)
def _create_inputs(number, size): def _create_inputs(number, size):
...@@ -174,6 +151,9 @@ def _create_inputs(number, size): ...@@ -174,6 +151,9 @@ def _create_inputs(number, size):
async def post_requests(num_requests, store_inputs_in_request): async def post_requests(num_requests, store_inputs_in_request):
"""
Post requests to add_multiply_divide operator.
"""
ucp.reset() ucp.reset()
timeout = 5 timeout = 5
...@@ -183,7 +163,7 @@ async def post_requests(num_requests, store_inputs_in_request): ...@@ -183,7 +163,7 @@ async def post_requests(num_requests, store_inputs_in_request):
request_plane = NatsRequestPlane(f"nats://localhost:{NATS_PORT}") request_plane = NatsRequestPlane(f"nats://localhost:{NATS_PORT}")
await request_plane.connect() await request_plane.connect()
add_multiply_divide_model = RemoteOperator( add_multiply_divide_operator = RemoteOperator(
"add_multiply_divide", 1, request_plane, data_plane "add_multiply_divide", 1, request_plane, data_plane
) )
...@@ -194,13 +174,13 @@ async def post_requests(num_requests, store_inputs_in_request): ...@@ -194,13 +174,13 @@ async def post_requests(num_requests, store_inputs_in_request):
for i, input_ in enumerate(inputs): for i, input_ in enumerate(inputs):
request_id = str(i) request_id = str(i)
request = add_multiply_divide_model.create_request( request = add_multiply_divide_operator.create_request(
inputs={"int64_input": input_}, request_id=request_id inputs={"int64_input": input_}, request_id=request_id
) )
if store_inputs_in_request: if store_inputs_in_request:
request.store_inputs_in_request.add("int64_input") request.store_inputs_in_request.add("int64_input")
print(request) print(request)
results.append(add_multiply_divide_model.async_infer(request)) results.append(add_multiply_divide_operator.async_infer(request))
expected_results[request_id] = outputs[i] expected_results[request_id] = outputs[i]
for result in asyncio.as_completed(results): for result in asyncio.as_completed(results):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment