Commit 53712d62 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

test: Simplify e2e integration test for worker (#24)


Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent ab274e74
......@@ -36,6 +36,9 @@ CMakeCache.txt
*_pb2.pyi
*.svg
*pytest_report.md
*pytest_report.xml
**/__pycache__
**/venv
*.cache
......
......@@ -36,6 +36,7 @@ class Deployment:
args=[worker_config],
)
)
self._workers[-1].start()
def shutdown(self, join=True, timeout=10):
for worker in self._workers:
......
......@@ -15,17 +15,13 @@
import asyncio
import logging
import multiprocessing
import signal
import subprocess
import sys
import time
import pytest
import pytest_asyncio
from triton_distributed.icp.nats_request_plane import NatsServer
from triton_distributed.worker.log_formatter import LOGGER_NAME, setup_logger
from triton_distributed.worker.worker import Worker
from triton_distributed.worker.log_formatter import LOGGER_NAME
logger = logging.getLogger(LOGGER_NAME)
......@@ -36,95 +32,6 @@ TEST_API_SERVER_MODEL_REPO_PATH = (
)
async def _wait_for_tasks(loop):
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
try:
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
print("Encountered an error in task clean-up: %s", e)
print("Stopping the event loop")
loop.stop()
def _run_worker(name, queue, worker_config):
tensor_store_keys = None
try:
with open(f"{name}.worker.stdout.log", "w") as output_:
with open(f"{name}.worker.stderr.log", "w") as output_err:
with open(f"{name}.worker.triton.log", "w"):
sys.stdout = output_
sys.stderr = output_err
triton_log_filename = f"{name}.worker.triton.log"
setup_logger(log_level=worker_config.log_level)
worker_config.triton_log_file = triton_log_filename
worker_config.name = name
try:
worker = Worker(worker_config)
except Exception as e:
queue.put(f"Failed to start {name}: {e}")
logger.exception("Failed to instantiate a worker class")
loop = asyncio.new_event_loop()
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for sig in signals:
loop.add_signal_handler(
sig, lambda s=sig: asyncio.create_task(worker.shutdown(s)) # type: ignore
)
try:
queue.put("READY")
loop.run_until_complete(worker.serve())
except asyncio.CancelledError:
print("server cancellation detected")
finally:
loop.run_until_complete(_wait_for_tasks(loop))
loop.close()
tensor_store_keys = list(
worker._data_plane._tensor_store.keys()
)
sys.exit(len(tensor_store_keys))
except Exception as e:
print(f"Worker Serving Failed to start: {e}")
queue.put(f"Failed to start {name}: {e}")
raise e
class WorkerManager:
ctx = multiprocessing.get_context("spawn")
@staticmethod
def setup_worker_process(operators, name, queue, worker_config):
worker_config.name = name
worker_config.operators = operators
process = WorkerManager.ctx.Process(
target=_run_worker,
args=(name, queue, worker_config),
name=name,
)
process.start()
return process
@staticmethod
def cleanup_workers(workers, check_status=True):
for worker in workers:
print(f"Terminating {worker.name} worker", flush=True)
worker.terminate()
for worker in workers:
worker.join()
print(f"{worker.name} exited with {worker.exitcode} stored tensors")
assert (
worker.exitcode == 0 if check_status else True
), f"{worker.name} exited with {worker.exitcode} stored tensors"
@pytest.fixture
def worker_manager():
return WorkerManager
@pytest.fixture(scope="session")
def nats_server():
server = NatsServer()
......
......@@ -15,8 +15,10 @@
import asyncio
import logging
import shutil
import sys
from multiprocessing import Manager, Process
from multiprocessing import Process
from pathlib import Path
import cupy
import numpy
......@@ -25,6 +27,7 @@ import ucp
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.deployment import Deployment
from triton_distributed.worker.log_formatter import LOGGER_NAME
from triton_distributed.worker.operator import OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator
......@@ -32,8 +35,11 @@ from triton_distributed.worker.triton_core_operator import TritonCoreOperator
from triton_distributed.worker.worker import WorkerConfig
NATS_PORT = 4223
MODEL_REPOSITORY = "/workspace/worker/tests/python/integration/operators/models"
WORKFLOW_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
MODEL_REPOSITORY = (
"/workspace/worker/tests/python/integration/operators/triton_core_models"
)
OPERATORS_REPOSITORY = "/workspace/worker/tests/python/integration/operators"
TRITON_LOG_FILE = "triton.log"
TRITON_LOG_LEVEL = 6
logger = logging.getLogger(LOGGER_NAME)
......@@ -55,93 +61,64 @@ pytestmark = pytest.mark.pre_merge
@pytest.fixture
def workers(worker_manager, request):
worker_config = WorkerConfig(
request_plane=NatsRequestPlane,
data_plane=UcpDataPlane,
request_plane_args=([], {"request_plane_uri": f"nats://localhost:{NATS_PORT}"}),
log_level=TRITON_LOG_LEVEL,
)
store_outputs_in_response = request.getfixturevalue("store_outputs_in_response")
def workers(request):
operator_configs = {}
add_model = OperatorConfig(
name="add",
implementation=TritonCoreOperator,
version=1,
max_inflight_requests=10,
parameters={"store_outputs_in_response": store_outputs_in_response},
repository=MODEL_REPOSITORY,
)
multiply_model = OperatorConfig(
name="multiply",
implementation=TritonCoreOperator,
version=1,
max_inflight_requests=10,
parameters={"store_outputs_in_response": store_outputs_in_response},
repository=MODEL_REPOSITORY,
)
divide_model = OperatorConfig(
name="divide",
store_outputs_in_response = request.getfixturevalue("store_outputs_in_response")
# Add configs for triton core operators
triton_core_operators = ["add", "multiply", "divide"]
for operator_name in triton_core_operators:
operator_configs[operator_name] = OperatorConfig(
name=operator_name,
implementation=TritonCoreOperator,
version=1,
max_inflight_requests=10,
parameters={"store_outputs_in_response": store_outputs_in_response},
repository=MODEL_REPOSITORY,
)
workflow = OperatorConfig(
name="add_multiply_divide",
# Add configs for other custom operators
operator_name = "add_multiply_divide"
operator_configs[operator_name] = OperatorConfig(
name=operator_name,
implementation="add_multiply_divide:AddMultiplyDivide",
version=1,
max_inflight_requests=10,
parameters={"store_outputs_in_response": store_outputs_in_response},
repository=WORKFLOW_REPOSITORY,
repository=OPERATORS_REPOSITORY,
)
with Manager() as manager:
workers = []
queues = []
worker_configs = []
queues.append(manager.Queue(maxsize=1))
workers.append(
worker_manager.setup_worker_process(
[add_model], "add", queues[-1], worker_config
)
)
test_log_directory = Path(__file__).with_suffix("")
if test_log_directory.exists():
shutil.rmtree(test_log_directory, ignore_errors=True)
test_log_directory.mkdir()
queues.append(manager.Queue(maxsize=1))
workers.append(
worker_manager.setup_worker_process(
[multiply_model], "multiply", queues[-1], worker_config
)
)
queues.append(manager.Queue(maxsize=1))
workers.append(
worker_manager.setup_worker_process(
[divide_model], "divide", queues[-1], worker_config
)
)
queues.append(manager.Queue(maxsize=1))
workers.append(
worker_manager.setup_worker_process(
[workflow], "add_multiply_divide", queues[-1], worker_config
# We will instantiate a worker for each operator
for name, operator_config in operator_configs.items():
# Set the logging directory
log_dir = test_log_directory / name
worker_configs.append(
WorkerConfig(
request_plane=NatsRequestPlane,
data_plane=UcpDataPlane,
request_plane_args=(
[],
{"request_plane_uri": f"nats://localhost:{NATS_PORT}"},
),
log_level=TRITON_LOG_LEVEL,
log_dir=str(log_dir),
triton_log_path=str(log_dir / TRITON_LOG_FILE),
operators=[operator_config],
)
)
workers_failed = False
status_list = []
for queue, worker in zip(queues, workers):
status = queue.get()
status_list.append(status)
if status != "READY":
workers_failed = True
worker_deployment = Deployment(worker_configs)
if workers_failed:
worker_manager.cleanup_workers(workers, check_status=False)
raise Exception(f"Failed to start worker processes: {status_list}")
yield workers
worker_manager.cleanup_workers(workers)
worker_deployment.start()
yield worker_deployment
worker_deployment.shutdown()
def _create_inputs(number, size):
......@@ -174,6 +151,9 @@ def _create_inputs(number, size):
async def post_requests(num_requests, store_inputs_in_request):
"""
Post requests to add_multiply_divide operator.
"""
ucp.reset()
timeout = 5
......@@ -183,7 +163,7 @@ async def post_requests(num_requests, store_inputs_in_request):
request_plane = NatsRequestPlane(f"nats://localhost:{NATS_PORT}")
await request_plane.connect()
add_multiply_divide_model = RemoteOperator(
add_multiply_divide_operator = RemoteOperator(
"add_multiply_divide", 1, request_plane, data_plane
)
......@@ -194,13 +174,13 @@ async def post_requests(num_requests, store_inputs_in_request):
for i, input_ in enumerate(inputs):
request_id = str(i)
request = add_multiply_divide_model.create_request(
request = add_multiply_divide_operator.create_request(
inputs={"int64_input": input_}, request_id=request_id
)
if store_inputs_in_request:
request.store_inputs_in_request.add("int64_input")
print(request)
results.append(add_multiply_divide_model.async_infer(request))
results.append(add_multiply_divide_operator.async_infer(request))
expected_results[request_id] = outputs[i]
for result in asyncio.as_completed(results):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment