Commit 4698c0f4 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

feat: hello world


Co-authored-by: default avatarPiotr Marcinkiewicz <piotrm@nvidia.com>
Co-authored-by: default avatarTanmay Verma <tanmay2592@gmail.com>
parent e6c12674
...@@ -29,14 +29,17 @@ from tritonserver import InvalidArgumentError ...@@ -29,14 +29,17 @@ from tritonserver import InvalidArgumentError
class RemoteOperator: class RemoteOperator:
def __init__( def __init__(
self, self,
name: str, operator: str | tuple[str, int],
version: int,
request_plane: RequestPlane, request_plane: RequestPlane,
data_plane: DataPlane, data_plane: DataPlane,
component_id: Optional[uuid.UUID] = None, component_id: Optional[uuid.UUID] = None,
): ):
self.name = name if isinstance(operator, str):
self.version = version self.name = operator
self.version = 1
else:
self.name = operator[0]
self.version = operator[1]
self._request_plane = request_plane self._request_plane = request_plane
self._data_plane = data_plane self._data_plane = data_plane
self.component_id = component_id self.component_id = component_id
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
import asyncio import asyncio
import importlib import importlib
import logging import logging
import multiprocessing
import os import os
import pathlib import pathlib
import signal import signal
...@@ -50,7 +49,7 @@ class WorkerConfig: ...@@ -50,7 +49,7 @@ class WorkerConfig:
data_plane: Type[DataPlane] = UcpDataPlane data_plane: Type[DataPlane] = UcpDataPlane
request_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {})) request_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {}))
data_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {})) data_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {}))
log_level: int = 0 log_level: Optional[int] = None
operators: list[OperatorConfig] = field(default_factory=list) operators: list[OperatorConfig] = field(default_factory=list)
triton_log_path: Optional[str] = None triton_log_path: Optional[str] = None
name: str = str(uuid.uuid1()) name: str = str(uuid.uuid1())
...@@ -75,6 +74,8 @@ class Worker: ...@@ -75,6 +74,8 @@ class Worker:
self._triton_log_path = config.triton_log_path self._triton_log_path = config.triton_log_path
self._name = config.name self._name = config.name
self._log_level = config.log_level self._log_level = config.log_level
if self._log_level is None:
self._log_level = 0
self._operator_configs = config.operators self._operator_configs = config.operators
self._log_dir = config.log_dir self._log_dir = config.log_dir
...@@ -87,6 +88,7 @@ class Worker: ...@@ -87,6 +88,7 @@ class Worker:
self._operators: dict[tuple[str, int], Operator] = {} self._operators: dict[tuple[str, int], Operator] = {}
self._metrics_port = config.metrics_port self._metrics_port = config.metrics_port
self._metrics_server: Optional[uvicorn.Server] = None self._metrics_server: Optional[uvicorn.Server] = None
self._component_id = self._request_plane.component_id
def _import_operators(self): def _import_operators(self):
for operator_config in self._operator_configs: for operator_config in self._operator_configs:
...@@ -225,6 +227,7 @@ class Worker: ...@@ -225,6 +227,7 @@ class Worker:
await asyncio.gather(*handlers) await asyncio.gather(*handlers)
async def serve(self): async def serve(self):
error = None
self._triton_core = tritonserver.Server( self._triton_core = tritonserver.Server(
model_repository=".", model_repository=".",
log_error=True, log_error=True,
...@@ -258,6 +261,7 @@ class Worker: ...@@ -258,6 +261,7 @@ class Worker:
except Exception as e: except Exception as e:
logger.exception("Encountered an error in worker: %s", e) logger.exception("Encountered an error in worker: %s", e)
self._stop_requested = True self._stop_requested = True
error = e
logger.info("worker store: %s", list(self._data_plane._tensor_store.keys())) logger.info("worker store: %s", list(self._data_plane._tensor_store.keys()))
logger.info("Worker stopped...") logger.info("Worker stopped...")
logger.info( logger.info(
...@@ -272,6 +276,7 @@ class Worker: ...@@ -272,6 +276,7 @@ class Worker:
if self._metrics_server: if self._metrics_server:
self._metrics_server.should_exit = True self._metrics_server.should_exit = True
await self._metrics_server.shutdown() await self._metrics_server.shutdown()
return error
async def shutdown(self, signal): async def shutdown(self, signal):
logger.info("Received exit signal %s...", signal.name) logger.info("Received exit signal %s...", signal.name)
...@@ -326,13 +331,20 @@ class Worker: ...@@ -326,13 +331,20 @@ class Worker:
loop.stop() loop.stop()
def start(self): def start(self):
exit_condition = None
if self._log_dir: if self._log_dir:
pid = os.getpid()
os.makedirs(self._log_dir, exist_ok=True) os.makedirs(self._log_dir, exist_ok=True)
stdout_path = os.path.join(self._log_dir, f"{self._name}.stdout.log") stdout_path = os.path.join(
stderr_path = os.path.join(self._log_dir, f"{self._name}.stderr.log") self._log_dir, f"{self._name}.{self._component_id}.{pid}.stdout.log"
)
stderr_path = os.path.join(
self._log_dir, f"{self._name}.{self._component_id}.{pid}.stderr.log"
)
if not self._triton_log_path: if not self._triton_log_path:
self._triton_log_path = os.path.join( self._triton_log_path = os.path.join(
self._log_dir, f"{self._name}.triton.log" self._log_dir, f"{self._name}.{self._component_id}.{pid}.triton.log"
) )
sys.stdout = open(stdout_path, "w", buffering=1) sys.stdout = open(stdout_path, "w", buffering=1)
sys.stderr = open(stderr_path, "w", buffering=1) sys.stderr = open(stderr_path, "w", buffering=1)
...@@ -349,55 +361,34 @@ class Worker: ...@@ -349,55 +361,34 @@ class Worker:
loop.add_signal_handler( loop.add_signal_handler(
sig, lambda s=sig: asyncio.create_task(self.shutdown(s)) # type: ignore sig, lambda s=sig: asyncio.create_task(self.shutdown(s)) # type: ignore
) )
serve_result = None
try: try:
if self._metrics_port: if self._metrics_port:
loop.create_task(self.serve()) serve_result = loop.create_task(self.serve())
self._metrics_server = self._setup_metrics_server() self._metrics_server = self._setup_metrics_server()
assert self._metrics_server, "Unable to start metrics server" assert self._metrics_server, "Unable to start metrics server"
loop.run_until_complete(self._metrics_server.serve()) loop.run_until_complete(self._metrics_server.serve())
else: else:
loop.run_until_complete(self.serve()) serve_result = loop.run_until_complete(self.serve())
except asyncio.CancelledError: except asyncio.CancelledError:
pass
logger.info("Worker cancelled!") logger.info("Worker cancelled!")
finally: finally:
loop.run_until_complete(self._wait_for_tasks(loop)) loop.run_until_complete(self._wait_for_tasks(loop))
loop.close() loop.close()
logger.info("Successfully shutdown worker.") logger.info("Successfully shutdown worker.")
if isinstance(serve_result, asyncio.Task):
exit_condition = serve_result.result()
else:
exit_condition = serve_result
sys.stdout.flush() sys.stdout.flush()
sys.stderr.flush() sys.stderr.flush()
if self._log_dir: if self._log_dir:
sys.stdout.close() sys.stdout.close()
sys.stderr.close() sys.stderr.close()
if exit_condition is not None:
class Deployment: sys.exit(1)
def __init__(self, worker_configs: list[WorkerConfig]): else:
self._process_context = multiprocessing.get_context("spawn") sys.exit(0)
self._worker_configs = worker_configs
self._workers: list[multiprocessing.context.SpawnProcess] = []
@staticmethod
def _start_worker(worker_config):
Worker(worker_config).start()
def start(self):
for worker_config in self._worker_configs:
self._workers.append(
self._process_context.Process(
target=Deployment._start_worker,
name=worker_config.name,
args=[worker_config],
)
)
def shutdown(self, join=True, timeout=10):
for worker in self._workers:
worker.terminate()
if join:
for worker in self._workers:
worker.join(timeout)
for worker in self._workers:
if worker.is_alive():
worker.kill()
worker.join(timeout)
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# # SPDX-License-Identifier: Apache-2.0
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions # Licensed under the Apache License, Version 2.0 (the "License");
# are met: # you may not use this file except in compliance with the License.
# * Redistributions of source code must retain the above copyright # You may obtain a copy of the License at
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright # http://www.apache.org/licenses/LICENSE-2.0
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution. # Unless required by applicable law or agreed to in writing, software
# * Neither the name of NVIDIA CORPORATION nor the names of its # distributed under the License is distributed on an "AS IS" BASIS,
# contributors may be used to endorse or promote products derived # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# from this software without specific prior written permission. # See the License for the specific language governing permissions and
# # limitations under the License.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio import asyncio
import gc import gc
...@@ -162,7 +151,7 @@ class TritonPythonModel: ...@@ -162,7 +151,7 @@ class TritonPythonModel:
"string_value" "string_value"
] ]
self._remote_operator = RemoteOperator( self._remote_operator = RemoteOperator(
self._remote_worker_name, 1, self._request_plane, self._data_plane self._remote_worker_name, self._request_plane, self._data_plane
) )
# Starting the response thread. It allows API Server to keep making progress while # Starting the response thread. It allows API Server to keep making progress while
......
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# # SPDX-License-Identifier: Apache-2.0
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions # Licensed under the Apache License, Version 2.0 (the "License");
# are met: # you may not use this file except in compliance with the License.
# * Redistributions of source code must retain the above copyright # You may obtain a copy of the License at
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright # http://www.apache.org/licenses/LICENSE-2.0
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution. # Unless required by applicable law or agreed to in writing, software
# * Neither the name of NVIDIA CORPORATION nor the names of its # distributed under the License is distributed on an "AS IS" BASIS,
# contributors may be used to endorse or promote products derived # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# from this software without specific prior written permission. # See the License for the specific language governing permissions and
# # limitations under the License.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
name: "mock_disaggregated_serving" name: "mock_disaggregated_serving"
backend: "python" backend: "python"
......
...@@ -35,14 +35,12 @@ class AddMultiplyDivide(Operator): ...@@ -35,14 +35,12 @@ class AddMultiplyDivide(Operator):
self._request_plane = request_plane self._request_plane = request_plane
self._data_plane = data_plane self._data_plane = data_plane
self._parameters = parameters self._parameters = parameters
self._add_model = RemoteOperator( self._add_model = RemoteOperator("add", self._request_plane, self._data_plane)
"add", 1, self._request_plane, self._data_plane
)
self._multiply_model = RemoteOperator( self._multiply_model = RemoteOperator(
"multiply", 1, self._request_plane, self._data_plane "multiply", self._request_plane, self._data_plane
) )
self._divide_model = RemoteOperator( self._divide_model = RemoteOperator(
"divide", 1, self._request_plane, self._data_plane "divide", self._request_plane, self._data_plane
) )
async def execute(self, requests: list[RemoteInferenceRequest]): async def execute(self, requests: list[RemoteInferenceRequest]):
......
...@@ -37,16 +37,16 @@ class MockDisaggregatedServing(Operator): ...@@ -37,16 +37,16 @@ class MockDisaggregatedServing(Operator):
self._data_plane = data_plane self._data_plane = data_plane
self._params = params self._params = params
self._preprocessing_model = RemoteOperator( self._preprocessing_model = RemoteOperator(
"preprocessing", 1, self._request_plane, self._data_plane "preprocessing", self._request_plane, self._data_plane
) )
self._context_model = RemoteOperator( self._context_model = RemoteOperator(
"context", 1, self._request_plane, self._data_plane "context", self._request_plane, self._data_plane
) )
self._generate_model = RemoteOperator( self._generate_model = RemoteOperator(
"generation", 1, self._request_plane, self._data_plane "generation", self._request_plane, self._data_plane
) )
self._postprocessing_model = RemoteOperator( self._postprocessing_model = RemoteOperator(
"postprocessing", 1, self._request_plane, self._data_plane "postprocessing", self._request_plane, self._data_plane
) )
self._logger = logger self._logger = logger
......
...@@ -160,7 +160,7 @@ async def post_requests(num_requests, store_inputs_in_request): ...@@ -160,7 +160,7 @@ async def post_requests(num_requests, store_inputs_in_request):
await request_plane.connect() await request_plane.connect()
add_multiply_divide_operator = RemoteOperator( add_multiply_divide_operator = RemoteOperator(
"add_multiply_divide", 1, request_plane, data_plane "add_multiply_divide", request_plane, data_plane
) )
results = [] results = []
......
...@@ -115,7 +115,7 @@ async def post_requests(num_requests, num_targets): ...@@ -115,7 +115,7 @@ async def post_requests(num_requests, num_targets):
request_plane = NatsRequestPlane(f"nats://localhost:{NATS_PORT}") request_plane = NatsRequestPlane(f"nats://localhost:{NATS_PORT}")
await request_plane.connect() await request_plane.connect()
identity_operator = RemoteOperator("identity", 1, request_plane, data_plane) identity_operator = RemoteOperator("identity", request_plane, data_plane)
target_components = set() target_components = set()
target_component_list: list[uuid.UUID] = [] target_component_list: list[uuid.UUID] = []
......
...@@ -156,7 +156,7 @@ async def post_requests(num_requests): ...@@ -156,7 +156,7 @@ async def post_requests(num_requests):
await request_plane.connect() await request_plane.connect()
mock_disaggregated_serving_operator = RemoteOperator( mock_disaggregated_serving_operator = RemoteOperator(
"mock_disaggregated_serving", 1, request_plane, data_plane "mock_disaggregated_serving", request_plane, data_plane
) )
expected_results = {} expected_results = {}
......
...@@ -133,7 +133,7 @@ def run( ...@@ -133,7 +133,7 @@ def run(
asyncio.get_event_loop().run_until_complete(request_plane.connect()) asyncio.get_event_loop().run_until_complete(request_plane.connect())
identity_operator = RemoteOperator( identity_operator = RemoteOperator(
"identity", 1, request_plane, data_plane_tracker._data_plane "identity", request_plane, data_plane_tracker._data_plane
) )
inputs, outputs = _create_inputs(1, tensor_size_in_kb) inputs, outputs = _create_inputs(1, tensor_size_in_kb)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment