Commit b9a0ce2c authored by ptarasiewiczNV, committed by GitHub

Clean-up vLLM example to use Operator API (#68)


Signed-off-by: ptarasiewiczNV <104908264+ptarasiewiczNV@users.noreply.github.com>
Co-authored-by: Neelay Shah <neelays@nvidia.com>
parent 022b6db5
......@@ -19,9 +19,9 @@ import time
from pathlib import Path
from llm.vllm.operators.vllm import (
VllmBaselineOperator,
VllmContextOperator,
VllmGenerateOperator,
VllmOperator,
)
from triton_distributed.worker import Deployment, OperatorConfig, WorkerConfig
......@@ -69,7 +69,7 @@ def _create_generate_op(name, args, max_inflight_requests):
def _create_baseline_op(name, args, max_inflight_requests):
return OperatorConfig(
name=name,
implementation=VllmBaselineOperator,
implementation=VllmOperator,
max_inflight_requests=int(max_inflight_requests),
parameters=vars(args),
)
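For context, operator parameters travel as a plain dict: the deployment passes vars(args) into OperatorConfig, and the operator later rebuilds an argparse.Namespace from the same dict in _init_stages. A minimal standalone sketch of that round trip (the field names below are illustrative, not the real argument set):
import argparse

# Deployment side: flatten the parsed CLI arguments into a parameters dict.
args = argparse.Namespace(model_name="example-model", baseline_tp_size=1)  # illustrative fields
parameters = vars(args)

# Operator side: _init_stages() rebuilds a Namespace from that dict.
rebuilt = argparse.Namespace(**parameters)
assert rebuilt.model_name == "example-model"
assert rebuilt.baseline_tp_size == 1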
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import inspect
import os
import time
......@@ -29,7 +30,15 @@ LOGGER = vllm.logger.init_logger(__name__)
RETURN_EVERY_N = 1000000
class SingleComputePipeline:
class Stage(abc.ABC):
@abc.abstractmethod
async def __call__(
self, input_payload: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
yield {}
class AggregatedStage(Stage):
def __init__(
self,
**kwargs,
......@@ -80,7 +89,7 @@ class SingleComputePipeline:
yield {"outputs": {}, "error": str(e), "final": True}
class PrefillStage:
class PrefillStage(Stage):
def __init__(
self,
generate_tensor_parallel_size: Optional[int] = None,
......@@ -159,7 +168,6 @@ class PrefillStage:
"outputs": {}, # See line 195 for context
"error": None,
"parameters": {
**input_payload["parameters"],
"context_worker_id": os.environ["VLLM_WORKER_ID"],
"first_token": result.outputs[0].token_ids[0],
"seq_len": len(result.prompt_token_ids),
......@@ -172,7 +180,7 @@ class PrefillStage:
yield {"outputs": {}, "error": str(e), "final": True}
class GenerateStage:
class GenerateStage(Stage):
def __init__(
self,
**kwargs,
......@@ -230,28 +238,3 @@ class GenerateStage:
}
counter += 1
LOGGER.debug("results_generator finished for generate")
class DisaggregatedPipeline:
def __init__(
self,
stage,
**kwargs,
):
if stage == "prefill":
LOGGER.info(f"initialize prefill {kwargs}")
self.stage = PrefillStage(**kwargs) # type: ignore
elif stage == "generate":
LOGGER.info(f"initialize generate {kwargs}")
self.stage = GenerateStage(**kwargs) # type: ignore
else:
raise ValueError(f"Unknown stage: {stage}")
async def __call__(
self, input_payload: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
LOGGER.debug("Start pipeline")
async for result in self.stage(input_payload):
LOGGER.debug("yield result")
yield result
LOGGER.debug("Pipeline generator finished")
import argparse
import json
import logging
from dataclasses import field
from typing import Any, Optional
from typing import AsyncGenerator, List, Optional
import numpy as np
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker import Operator, RemoteInferenceRequest
from .vllm_disaggregated.pipelines import (
GenerateStage,
PrefillStage,
SingleComputePipeline,
from triton_distributed.worker import (
Operator,
RemoteInferenceRequest,
RemoteInferenceResponse,
RemoteOperator,
)
from .vllm_disaggregated.stage_executor import PiplineStageExecutor
from .stages import AggregatedStage, GenerateStage, PrefillStage, Stage
class VllmContextOperator(Operator):
class VllmOperator(Operator):
def __init__(
self,
name: str,
......@@ -26,13 +30,49 @@ class VllmContextOperator(Operator):
default_factory=dict
),
repository: Optional[str] = None,
logger: Optional[Any] = None,
logger: Optional[logging.Logger] = None,
):
self.name = name
self.version = version
self.request_plane = request_plane
self.data_plane = data_plane
if logger is None:
self.logger = logging.getLogger(__name__)
else:
self.logger = logger
self._stage: Stage
self._init_stages(parameters)
async def execute(self, requests: List[RemoteInferenceRequest]) -> None:
for request in requests:
response_sender = request.response_sender()
try:
inputs, parameters = self._prepare_inputs(request)
self.logger.debug("Processing request")
async for response in self._stage(
{
"inputs": inputs,
"parameters": parameters,
}
):
self.logger.debug("Sending response")
await response_sender.send(**response)
self.logger.debug("Response send")
except Exception as e:
self.logger.error(f"Error processing request: {e}")
await response_sender.send(error=e, final=True)
def _init_stages(
self,
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
),
):
args = argparse.Namespace(**parameters) # type: ignore
stage = PrefillStage(
self._stage = AggregatedStage(
model=args.model_name,
tensor_parallel_size=args.context_tp_size,
generate_tensor_parallel_size=args.generate_tp_size,
tensor_parallel_size=args.baseline_tp_size,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=args.max_model_len,
dtype=args.dtype,
......@@ -45,33 +85,32 @@ class VllmContextOperator(Operator):
disable_async_output_proc=args.disable_async_output_proc,
disable_log_stats=args.disable_log_stats,
)
self.executor = PiplineStageExecutor(
args, request_plane, stage, "prefill", "generate"
)
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
await self.executor.process_requests(requests)
@staticmethod
def _prepare_inputs(request: RemoteInferenceRequest):
inputs, parameters = {}, {}
for input_name, input_data in request.inputs.items():
inputs[input_name] = np.from_dlpack(input_data)
for key, value in request.parameters.items():
if isinstance(value, str) and value.startswith("JSON:"):
parameters[key] = json.loads(value[5:])
else:
parameters[key] = value
return inputs, parameters
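_prepare_inputs also decodes dict-valued parameters that the sending side serializes with a "JSON:" prefix (see RemoteModelConnector.inference further down). A minimal standalone sketch of that convention using plain dicts rather than the real request objects; the parameter names and values are illustrative:
import json

def encode_parameters(parameters):
    # Sending side: dict values are serialized with a "JSON:" prefix.
    return {
        key: "JSON:" + json.dumps(value) if isinstance(value, dict) else value
        for key, value in parameters.items()
    }

def decode_parameters(parameters):
    # Receiving side: the prefix is stripped and the payload parsed back.
    return {
        key: json.loads(value[5:])
        if isinstance(value, str) and value.startswith("JSON:")
        else value
        for key, value in parameters.items()
    }

original = {"sampling_params": {"max_tokens": 16}, "text": "hello"}
assert decode_parameters(encode_parameters(original)) == original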
class VllmGenerateOperator(Operator):
def __init__(
class VllmContextOperator(VllmOperator):
def _init_stages(
self,
name: str,
version: int,
triton_core,
request_plane: RequestPlane,
data_plane: DataPlane,
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
),
repository: Optional[str] = None,
logger: Optional[Any] = None,
):
args = argparse.Namespace(**parameters) # type: ignore
args.worker_name = "generate"
stage = GenerateStage(
self._prefill_stage = PrefillStage(
model=args.model_name,
tensor_parallel_size=args.generate_tp_size,
tensor_parallel_size=args.context_tp_size,
generate_tensor_parallel_size=args.generate_tp_size,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=args.max_model_len,
dtype=args.dtype,
......@@ -84,30 +123,62 @@ class VllmGenerateOperator(Operator):
disable_async_output_proc=args.disable_async_output_proc,
disable_log_stats=args.disable_log_stats,
)
self.executor = PiplineStageExecutor(args, request_plane, stage, "generate")
self._generate_operator = RemoteOperator(
"generate", self.request_plane, self.data_plane
)
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
await self.executor.process_requests(requests)
async def execute(self, requests: List[RemoteInferenceRequest]) -> None:
for request in requests:
response_sender = request.response_sender()
try:
self.logger.info("Processing request")
inputs, parameters = self._prepare_inputs(request)
responses = [
response
async for response in self._prefill_stage(
{
"inputs": inputs,
"parameters": parameters,
}
)
]
self.logger.info("Prefill finished")
assert len(responses) == 1
response = responses[0]
self.logger.info("Processing generate")
generate_responses: AsyncGenerator[
RemoteInferenceResponse, None
] = await self._generate_operator.async_infer(
inputs=response["outputs"],
parameters={**request.parameters, **response["parameters"]},
)
async for generate_response in generate_responses:
self.logger.info("Sending response")
parameters = {"text": generate_response.parameters["text"]}
await response_sender.send(
outputs=generate_response.outputs,
parameters=parameters,
final=generate_response.final,
error=generate_response.error,
)
self.logger.info("Response send")
except Exception as e:
self.logger.error(f"Error processing request: {e}")
await response_sender.send(error=e, final=True)
class VllmBaselineOperator(Operator):
def __init__(
class VllmGenerateOperator(VllmOperator):
def _init_stages(
self,
name: str,
version: int,
triton_core,
request_plane: RequestPlane,
data_plane: DataPlane,
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
),
repository: Optional[str] = None,
logger: Optional[Any] = None,
):
args = argparse.Namespace(**parameters) # type: ignore
stage = SingleComputePipeline(
args.worker_name = "generate"
self._stage = GenerateStage(
model=args.model_name,
tensor_parallel_size=args.baseline_tp_size,
tensor_parallel_size=args.generate_tp_size,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=args.max_model_len,
dtype=args.dtype,
......@@ -120,7 +191,3 @@ class VllmBaselineOperator(Operator):
disable_async_output_proc=args.disable_async_output_proc,
disable_log_stats=args.disable_log_stats,
)
self.executor = PiplineStageExecutor(args, request_plane, stage, "baseline")
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
await self.executor.process_requests(requests)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import abc
import dataclasses
import typing
class TritonInferenceError(Exception):
"""Error occurred during Triton inference."""
@dataclasses.dataclass
class InferenceRequest:
"""Inference request."""
inputs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
parameters: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class InferenceResponse:
"""Inference response."""
outputs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
error: typing.Optional[str] = None
final: bool = False
parameters: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
class BaseTriton3Connector(abc.ABC):
"""Base class for Triton 3 connector."""
@abc.abstractmethod
def inference(
self, model_name: str, request: InferenceRequest
) -> typing.AsyncGenerator[InferenceResponse, None]:
"""Inference request to Triton 3 system.
Args:
model_name: Model name.
request: Inference request.
Returns:
Inference response.
Raises:
TritonInferenceError: error occurred during inference.
"""
raise NotImplementedError
async def list_models(self) -> typing.List[str]:
"""List models available in Triton 3 system.
Returns:
List of model names.
"""
raise NotImplementedError
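A minimal sketch of a concrete connector implementing this interface, assuming only the classes defined above; EchoConnector is hypothetical and simply returns the request inputs as outputs instead of contacting a Triton 3 system:
import asyncio
import typing
import numpy as np

class EchoConnector(BaseTriton3Connector):
    # Hypothetical connector used only to illustrate the interface.
    async def inference(
        self, model_name: str, request: InferenceRequest
    ) -> typing.AsyncGenerator[InferenceResponse, None]:
        yield InferenceResponse(outputs=dict(request.inputs), final=True)

    async def list_models(self) -> typing.List[str]:
        return ["echo"]

async def _demo():
    connector = EchoConnector()
    request = InferenceRequest(inputs={"a": np.array([1, 2, 3])})
    async for response in connector.inference("echo", request):
        print(response.outputs, response.final)

asyncio.run(_demo())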
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import Optional
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
# UCP data plane causes deadlocks when used more than once, so we use a singleton
_g_singletonic_data_plane = None
_g_singletonic_data_plane_connection_count = 0
_g_actual_host = None
_g_actual_port = None
def set_actual_host_port(host, port):
global _g_actual_host
global _g_actual_port
if _g_singletonic_data_plane is not None:
raise Exception("Cannot set actual host and port after data plane is created")
_g_actual_host = host
_g_actual_port = port
def set_data_plane(data_plane):
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
_g_singletonic_data_plane_connection_count = 1
_g_singletonic_data_plane = data_plane
class RemoteConnector:
"""Handle connection to both request and data planes."""
def __init__(
self,
request_plane: RequestPlane,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize RemoteConnector.
Args:
request_plane (RequestPlane): Request plane instance used for communication.
data_plane_host (Optional[str]): Data plane host (e.g. "localhost").
data_plane_port (int): Data plane port. Use 0 to let the system choose a port.
keep_dataplane_endpoints_open (bool): Keep data plane endpoints open to avoid reconnecting.
"""
global _g_singletonic_data_plane
global _g_actual_port
global _g_actual_host
self._request_plane = request_plane
if _g_singletonic_data_plane is None:
if _g_actual_host is not None:
data_plane_host = _g_actual_host
if _g_actual_port is not None:
data_plane_port = _g_actual_port
_g_singletonic_data_plane = UcpDataPlane(
hostname=data_plane_host,
port=data_plane_port,
keep_endpoints_open=keep_dataplane_endpoints_open,
)
self._connected = False
self._data_plane = _g_singletonic_data_plane
async def connect(self):
"""Connect to both request and data planes."""
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
assert _g_singletonic_data_plane
if _g_singletonic_data_plane_connection_count == 0:
_g_singletonic_data_plane.connect()
_g_singletonic_data_plane_connection_count += 1
self._connected = True
async def close(self):
"""Disconnect from both request and data planes."""
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
assert _g_singletonic_data_plane
await self._request_plane.close()
_g_singletonic_data_plane_connection_count -= 1
if _g_singletonic_data_plane_connection_count == 0:
_g_singletonic_data_plane.close()
_g_singletonic_data_plane = None
self._data_plane.close()
self._connected = False
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
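A minimal usage sketch for RemoteConnector, assuming a request_plane instance has been constructed elsewhere (its concrete type is not part of this diff); host and port values are illustrative:
async def run_with_connector(request_plane):
    # The async context manager connects on entry and releases the shared
    # UcpDataPlane singleton on exit (see the connection counting above).
    async with RemoteConnector(
        request_plane,
        data_plane_host="localhost",
        data_plane_port=0,  # 0 lets the system choose a port
        keep_dataplane_endpoints_open=True,
    ) as connector:
        assert connector._connected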
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio
import json
import typing
from typing import Any, Coroutine, List, Optional
import numpy as np
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.remote_response import AsyncRemoteResponseIterator
from triton_distributed.worker.remote_tensor import RemoteTensor
from .connector import BaseTriton3Connector, InferenceRequest, InferenceResponse
from .remote_connector import RemoteConnector
class RemoteModelConnector(BaseTriton3Connector):
"""Connector for Triton 3 model."""
def __init__(
self,
request_plane: RequestPlane,
model_name: str,
model_version: Optional[str] = None,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize Triton 3 connector.
Args:
request_plane: Request plane instance used for communication.
model_name: Model name.
model_version: Model version. Default is "1".
data_plane_host: Data plane host (e.g. "localhost").
data_plane_port: Data plane port (e.g. 8001). You can use 0 to let the system choose a port.
keep_dataplane_endpoints_open: Keep data plane endpoints open to avoid reconnecting. Default is False.
Example:
remote_model_connector = RemoteModelConnector(
nats_url="localhost:4222",
data_plane_host="localhost",
data_plane_port=8001,
model_name="model_name",
model_version="1",
)
async with remote_model_connector:
request = InferenceRequest(inputs={"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})
async for response in remote_model_connector.inference(model_name="model_name", request=request):
print(response.outputs)
"""
self._model = None
self._connector = RemoteConnector(
request_plane,
data_plane_host,
data_plane_port,
keep_dataplane_endpoints_open=keep_dataplane_endpoints_open,
)
self._model_name = model_name
if model_version is None:
model_version = "1"
self._model_version = model_version
async def connect(self):
"""Connect to Triton 3 server."""
await self._connector.connect()
self._model = RemoteOperator(
operator=self._model_name,
request_plane=self._connector._request_plane,
data_plane=self._connector._data_plane,
)
async def close(self):
"""Disconnect from Triton 3 server."""
await self._connector.close()
self._model = None
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
async def inference(
self, model_name: str, request: InferenceRequest
) -> typing.AsyncGenerator[InferenceResponse, None]:
"""Inference request to Triton 3 system.
Args:
model_name: Model name.
request: Inference request.
Returns:
Inference response.
Raises:
TritonInferenceError: error occurred during inference.
"""
if not self._connector._connected or self._model is None:
await self.connect()
else:
if self._model_name != model_name:
self._model_name = model_name
self._model_version = "1"
self._model = RemoteOperator(
operator=self._model_name,
request_plane=self._connector._request_plane,
data_plane=self._connector._data_plane,
)
results: List[Coroutine[Any, Any, AsyncRemoteResponseIterator]] = []
for key, value in request.parameters.items():
if isinstance(value, dict):
request.parameters[key] = "JSON:" + json.dumps(value)
assert self._model is not None
results.append(
self._model.async_infer(
inputs=request.inputs,
parameters=request.parameters,
)
)
for result in asyncio.as_completed(results):
responses = await result
async for response in responses:
triton_response = response.to_model_infer_response(
self._connector._data_plane
)
outputs = {}
for output in triton_response.outputs:
remote_tensor = RemoteTensor(output, self._connector._data_plane)
try:
local_tensor = remote_tensor.local_tensor
numpy_tensor = np.from_dlpack(local_tensor)
finally:
# FIXME: This is a workaround for the issue that the remote tensor
# is released after connection is closed.
remote_tensor.__del__()
outputs[output.name] = numpy_tensor
infer_response = InferenceResponse(
outputs=outputs,
final=response.final,
parameters=response.parameters,
)
yield infer_response
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional, Tuple
import numpy as np
from pydantic import BaseModel
from tritonserver import Tensor as TritonTensor
from tritonserver._api._response import InferenceResponse as TritonInferenceResponse
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker.remote_request import RemoteInferenceRequest
from triton_distributed.worker.remote_response import RemoteInferenceResponse
from .remote_connector import RemoteConnector
class LocalModel(BaseModel):
name: str
version: str
class RequestConverter:
"""Request converter. Class converts requests to convenient format for processing."""
def __init__(
self,
request_plane: RequestPlane,
model_name: str,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize RequestAdapter.
Args:
nats_url: NATS URL (e.g. "localhost:4222").
data_plane_host: Data plane host (e.g. "localhost").
data_plane_port: Data plane port (e.g. 8001). You can use 0 to let the system choose a port.
keep_dataplane_endpoints_open: Keep data plane endpoints open to avoid reconnecting. Default is False.
Example for async model:
worker = RequestConverter(
nats_url="localhost:4222",
data_plane_host="localhost",
data_plane_port=8001,
)
async with worker:
# This flow will process 10 requests at a time
processors = []
async def processing(request, callable):
inputs = request["inputs"]
parameters = request["parameters"]
output_tensor = inputs["a"] + inputs["b"]
try:
await callable({"c": output_tensor})
for _ in range(parameters["increment"]):
output_tensor += 1
await callable({"c": output_tensor})
finally:
await callable({"c": output_tensor}, final=True)
async for request, callable in worker.pull(model_name="model_name", batch_size=10):
# Check if batch size was reached
if len(processors) >= 10:
done, pending = await asyncio.wait(processors, return_when=asyncio.FIRST_COMPLETED)
processors = list(pending)
processors.append(processing(request, callable))
"""
self._connector = RemoteConnector(
request_plane,
data_plane_host,
data_plane_port,
keep_dataplane_endpoints_open=keep_dataplane_endpoints_open,
)
self._local_model = LocalModel(name=model_name, version="1")
async def connect(self):
"""Connect to Triton 3 server."""
await self._connector.connect()
async def close(self):
"""Disconnect from Triton 3 server."""
await self._connector.close()
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
async def pull(
self,
model_name: str,
model_version: Optional[str] = None,
batch_size: Optional[int] = None,
timeout: Optional[float] = None,
) -> AsyncGenerator[
Tuple[Dict[str, Any], Callable[[Dict[str, Any]], Awaitable[None]]], None
]:
"""Pull requests from request plane and data plane.
Pull returns an async generator that yields a tuple of request and callable.
Request contains inputs and parameters. Inputs are a dictionary of input names and numpy arrays. Parameters are
a dictionary of scalar parameters like sampling parameters in language models.
Callable is a function that takes outputs, parameters, error and final as arguments. Outputs are a dictionary of output names
and numpy arrays. Error is an Exception. Final is a boolean that indicates whether the response is final.
Args:
model_name: Model name.
model_version: Model version. Default is "1".
batch_size: Batch size. Default is 1.
timeout: Max duration of the pull request before it expires. Default is None.
Returns:
Inference request and callable.
Example:
worker = RequestConverter(
request_plane=request_plane,
model_name="model_name",
data_plane_host="localhost",
data_plane_port=8001,
)
async with worker:
# This flow will process single request at a time
async for request, callable in worker.pull(model_name="model_name"):
# This is a simple add model that increments the input tensor by the increment parameter
inputs = request["inputs"]
parameters = request["parameters"]
output_tensor = inputs["a"] + inputs["b"]
try:
await callable({"c": output_tensor})
for _ in range(parameters["increment"]):
output_tensor += 1
await callable({"c": output_tensor})
finally:
await callable({"c": output_tensor}, final=True)
"""
if not self._connector._connected:
await self.connect()
if model_version is None:
model_version = "1"
if batch_size is None:
batch_size = 1
local_model = LocalModel(
name=model_name,
version=model_version,
)
kwargs = {
"model_name": model_name,
"model_version": model_version,
"number_requests": batch_size,
}
if timeout is not None:
kwargs["timeout"] = timeout
while True:
requests_iterator = await self._connector._request_plane.pull_requests(
**kwargs
)
async for request in requests_iterator:
inputs, remote_request, return_callable = await self.adapt_request(
request, local_model
)
yield {
"inputs": inputs,
"parameters": remote_request.parameters,
}, return_callable
async def adapt_request(self, request, local_model: Optional[LocalModel] = None):
if local_model is None:
local_model = self._local_model
if isinstance(request, RemoteInferenceRequest):
remote_request = request
request = remote_request.to_model_infer_request()
else:
remote_request = RemoteInferenceRequest.from_model_infer_request(
request,
self._connector._data_plane,
self._connector._request_plane,
)
def produce_callable(request):
async def return_callable(
outputs: Dict[str, Any],
parameters: Optional[Dict[str, Any]] = None,
error: Optional[str] = None,
final: Optional[bool] = False,
) -> None:
request_id = request.parameters["icp_request_id"].string_param
infer_kwargs = {
"model": local_model,
"request_id": request_id,
}
if error is not None:
infer_kwargs["error"] = error
else:
outputs_tensors = {}
for name, value in outputs.items():
outputs_tensors[name] = TritonTensor.from_dlpack(value)
infer_kwargs["outputs"] = outputs_tensors
if final is not None:
infer_kwargs["final"] = final
if parameters is not None:
infer_kwargs["parameters"] = parameters
local_response = TritonInferenceResponse(**infer_kwargs)
remote_response = RemoteInferenceResponse.from_local_response(
local_response,
).to_model_infer_response(self._connector._data_plane)
# FIXME: This is a WAR for scenario where connector isn't
# connected when posting a response to request plane.
if not self._connector._connected:
await self.connect()
await self._connector._request_plane.post_response(
request,
remote_response,
)
return return_callable
return_callable = produce_callable(request)
inputs = {}
for name, input_tensor in remote_request.inputs.items():
local_tensor = input_tensor.local_tensor
numpy_tensor = np.from_dlpack(local_tensor)
input_tensor.__del__()
inputs[name] = numpy_tensor
for key, value in remote_request.parameters.items():
if isinstance(value, str) and value.startswith("JSON:"):
remote_request.parameters[key] = json.loads(value[5:])
return inputs, remote_request, return_callable
import asyncio
import enum
import logging
import os
from contextlib import nullcontext
import torch
from .connector import InferenceRequest
from .remote_model_connector import RemoteModelConnector
from .request_converter import RequestConverter
LOGGER = logging.getLogger(__name__)
class _ProfileState(enum.Enum):
NOT_STARTED = 0
STARTED = 1
STOPPED = 2
class PiplineStageExecutor:
def __init__(self, args, request_plane, stage, stage_name, next_stage_name=None):
self.args = args
self.stage = stage
self.stage_name = stage_name
self.is_context_stage = next_stage_name is not None
self.next_stage_name = next_stage_name
self.remote_model_connector = (
RemoteModelConnector(
request_plane=request_plane,
model_name=self.next_stage_name,
keep_dataplane_endpoints_open=True,
)
if self.is_context_stage
else None
)
self.request_converter = RequestConverter(
request_plane=request_plane,
keep_dataplane_endpoints_open=True,
model_name=self.args.worker_name,
)
self.request_counter = 0
self.profile_state = _ProfileState.NOT_STARTED
self.tasks = []
async def baseline_process(self, request, return_result):
try:
LOGGER.debug("Processing request")
async for response in self.stage(request):
LOGGER.debug("Sending response")
await return_result(**response)
LOGGER.debug("Response send")
except Exception as e:
LOGGER.error(f"Error processing request: {e}")
await return_result({"error": e, "final": True})
LOGGER.debug("Processing finished")
async def process(self, request, return_result):
LOGGER.debug("Processing request")
try:
LOGGER.debug(f"Stage {self.stage_name} execution")
responses = list([response async for response in self.stage(request)])
LOGGER.debug(f"Stage {self.stage_name} finished")
assert len(responses) == 1
response = responses[0]
parameters = response.get("parameters", {})
if not parameters:
raise RuntimeError(
f"ERROR: Response parameters from stage {self.stage_name} should not be empty!"
)
outputs = response.get("outputs", {})
request = InferenceRequest(inputs=outputs, parameters=parameters)
LOGGER.info(f"Next stage {self.next_stage_name} execution")
assert self.remote_model_connector is not None
async for response in self.remote_model_connector.inference(
model_name=self.next_stage_name, request=request
):
LOGGER.debug(f"Stage {self.stage_name} sending response")
await return_result(
outputs=response.outputs,
final=response.final,
parameters={"text": response.parameters["text"]},
)
LOGGER.debug(f"Stage {self.stage_name} sended response")
except Exception as e:
LOGGER.error(f"Error processing request: {e}", exc_info=True)
await return_result(outputs={}, error=e, final=True)
async def handle_pipelined_requests(self):
LOGGER.info(
f"Start handling requests stage_name {self.stage_name} args {self.args}"
)
async with self.request_converter, self.remote_model_connector or nullcontext():
LOGGER.info(f"Stage {self.stage_name} starts pulling")
async for request, return_result in self.request_converter.pull(
model_name=self.args.worker_name
):
# TODO ptarasiewicz - only one context or generate should be profiled at a time
await self.process_request(request, return_result)
LOGGER.info(f"Stage {self.stage_name} finished pulling")
async def process_requests(self, requests):
for raw_request in requests:
(
inputs,
remote_request,
return_callable,
) = await self.request_converter.adapt_request(raw_request)
request, return_result = {
"inputs": inputs,
"parameters": remote_request.parameters,
}, return_callable
await self.process_request(request, return_result)
async def process_request(self, request, return_result):
self._profile()
if self.is_context_stage:
process_function = self.process
else:
process_function = self.baseline_process
# self.request_counter += 1
LOGGER.debug(f"Stage {self.stage_name} pulled request")
self.tasks.append(asyncio.create_task(process_function(request, return_result)))
if len(self.tasks) >= self.args.max_batch_size:
LOGGER.debug(
f"Stage {self.stage_name} waiting some of {len(self.tasks)} requests to finish"
)
_, pending = await asyncio.wait(
self.tasks, return_when=asyncio.FIRST_COMPLETED
)
self.tasks = list(pending)
LOGGER.debug(
f"Stage {self.stage_name} finished some requests with {len(self.tasks)} to do"
)
def _profile(self):
if os.environ.get("RUN_PROFILING") == "1":
if (
self.profile_state == _ProfileState.NOT_STARTED
and self.request_counter > 100
):
LOGGER.info("Start profiling")
torch.cuda.profiler.start()
self.profile_state = _ProfileState.STARTED
elif (
self.profile_state == _ProfileState.STARTED
and self.request_counter > 120
):
LOGGER.info("Stop profiling")
torch.cuda.profiler.stop()
self.profile_state = _ProfileState.STOPPED
# can also use with torch.cuda.profiler.profile():
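As the comment notes, the same start/stop window can be expressed with the torch.cuda.profiler.profile() context manager; a minimal sketch, where the step function and iteration count are illustrative:
import torch

def run_profiled(step_fn, num_steps=20):
    # torch.cuda.profiler.profile() emits cudaProfilerStart()/cudaProfilerStop()
    # around the block, matching the explicit start()/stop() calls in _profile().
    with torch.cuda.profiler.profile():
        for _ in range(num_steps):
            step_fn()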